Merging of nodes is now done as in the OpenCL inference context.

PiperOrigin-RevId: 351217823
Change-Id: I00e795dc5bd0f2382e521799501ed23e2530a89a
Raman Sarokin 2021-01-11 13:03:10 -08:00 committed by TensorFlower Gardener
parent ae83b4fa29
commit 6e2d5e6c55
3 changed files with 64 additions and 415 deletions
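
In effect, this change replaces the chain-based ValidateOptimizeModel pass with a linear Merge pass that fuses each node into its producer when safe, mirroring the OpenCL inference context. Below is a minimal sketch of the fusion predicate the new pass applies, using a hypothetical simplified node type (SketchNode and CanMerge are illustrative names only; the real code operates on tflite::gpu::metal::NodeDescriptor):

#include <vector>

// Hypothetical simplified node type, for illustration only.
struct SketchNode {
  bool is_linkable;
  std::vector<int> src_tensor_ids;
  std::vector<int> dst_tensor_ids;
};

// Sketch of the condition under which the new Merge pass fuses a consumer
// into its producer: the producer has exactly one output, that output feeds
// exactly one consumer through the consumer's first input slot, and the
// consumer is a linkable task with a single output whose remaining inputs
// are already computed.
bool CanMerge(const SketchNode& producer, const SketchNode& consumer,
              int consumer_count, int link_index, bool consumer_ready) {
  return producer.dst_tensor_ids.size() == 1 && consumer_count == 1 &&
         link_index == 0 && consumer.is_linkable &&
         consumer.dst_tensor_ids.size() == 1 && consumer_ready;
}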


@ -104,7 +104,6 @@ using ComputeTaskDescriptorPtr = std::shared_ptr<ComputeTaskDescriptor>;
struct NodeDescriptor {
ComputeTaskDescriptorPtr task;
// Unique ID used to match graph compilation errors.
int id;
std::string description;
std::vector<ValueId> src_tensors_ids;


@ -38,6 +38,17 @@ namespace tflite {
namespace gpu {
namespace metal {
namespace {
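// Returns true if every source tensor of the node has already been computed.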
bool IsReady(const std::set<ValueId>& ready_tensors,
const NodeDescriptor& node) {
for (const ValueId in_id : node.src_tensors_ids) {
if (ready_tensors.find(in_id) == ready_tensors.end()) {
return false;
}
}
return true;
}
void AddUsage(ValueId id, int task_index,
std::map<ValueId, int2>* usage_records) {
auto it = usage_records->find(id);
@ -82,346 +93,6 @@ bool IsGenericAdd(const Node& node, const std::vector<Value*>& inputs,
return true;
}
// Captures information about the graph compilation, used to validate the
// graph. This information helps find the cause of performance degradation,
// such as misfusing.
struct OptimizationInfo {
// Initial operations count before compilation.
int operations_count;
// GPU task count after fusion and after splitting complex operations into a
// few GPU subtasks.
int gpu_tasks_count;
// Some operations are not used due to dependencies of the graph.
std::vector<int> unused_operations;
// Used inputs.
std::vector<ValueId> input_buffer_ids;
// Unused inputs. The requested outputs do not require these inputs to be used.
std::vector<ValueId> unused_input_buffer_ids;
// Outputs that are deduced from the graph but not requested by the user.
std::vector<ValueId> extra_output_buffer_ids;
// Outputs that are requested but can't be calculated by the graph.
std::vector<ValueId> missing_output_buffer_ids;
};
using FusionSequence = std::vector<NodeDescriptor>;
bool Contains(const std::vector<ValueId>& container, ValueId value) {
return std::find(container.begin(), container.end(), value) !=
container.end();
}
template <class T>
bool Contains(const std::vector<T>& container, ValueId value) {
for (const auto& buffer : container) {
if (buffer.id == value) {
return true;
}
}
return false;
}
// Checks if all elements of the narrow vector exist in the wide vector. Vectors
// are expected to be unsorted.
bool Contains(const std::vector<ValueId>& wide,
const std::vector<ValueId>& narrow) {
if (narrow.empty() || narrow.size() > wide.size()) {
return false;
}
std::set<ValueId> wide_sorted(wide.begin(), wide.end());
for (auto element : narrow) {
if (wide_sorted.find(element) == wide_sorted.end()) {
return false;
}
}
return true;
}
uint32_t BufferUseCount(ValueId id,
const std::list<NodeDescriptor>& descriptors,
std::list<FusionSequence>* chains) {
uint32_t use_count = 0;
// The buffer may be read by both processed and unprocessed operations.
for (auto& desc : descriptors) {
if (Contains(desc.src_tensors_ids, id)) {
use_count++;
}
}
for (auto& chain : *chains) {
for (auto& desc : chain) {
if (Contains(desc.src_tensors_ids, id)) {
use_count++;
}
}
}
return use_count;
}
// Checks whether the second operation can be linked to the first one. Linking
// is skipped when a conflict could occur: when the first operation's output is
// used by more than one other operation.
bool CanFuseOperations(const NodeDescriptor& first,
const NodeDescriptor& second,
const std::vector<ValueId>& output_ids,
const std::list<NodeDescriptor>& descriptors,
std::list<FusionSequence>* chains) {
return second.task->is_linkable &&
!Contains(output_ids, first.dst_tensors_ids[0]) &&
BufferUseCount(first.dst_tensors_ids[0], descriptors, chains) == 1;
}
// Takes an unsorted list of task descriptors and builds a list of chains. Each
// chain is a list of task descriptors that can be fused into a single GPU task.
// Building starts from the input IDs, and build statistics are filled in.
void BuildFusableChains(const std::vector<ValueId>& input_ids,
const std::vector<ValueId>& output_ids,
std::list<NodeDescriptor>* descriptors,
std::list<FusionSequence>* chains,
std::vector<int>* unused_ids) {
// Proxy tasks for inputs - only the output is valid for these elements.
for (auto input_id : input_ids) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->is_linkable = true;
desc->AddDstTensor("", {});
NodeDescriptor node;
node.task = desc;
node.dst_tensors_ids = {input_id};
chains->push_back({node});
}
if (descriptors->empty()) return;
// Consume all possible operations - grow the chains.
bool added;
do {
// At least one element must be added to some chain during this pass.
added = false;
for (auto it = descriptors->begin(); it != descriptors->end();) {
const NodeDescriptor& task_descriptor = *it;
// Gather the outputs of all chains to check against.
std::vector<ValueId> ready_buffer_ids;
ready_buffer_ids.reserve(chains->size());
for (const auto& chain : *chains) {
ready_buffer_ids.push_back(chain.back().dst_tensors_ids[0]);
}
// Check if all inputs of this operation are ready.
if (Contains(ready_buffer_ids, task_descriptor.src_tensors_ids)) {
// Now find a chain to fuse with.
bool fused = false;
for (auto& chain : *chains) {
// We can fuse only a single output for now.
const bool can_link = task_descriptor.src_tensors_ids[0] ==
chain.back().dst_tensors_ids[0];
if (can_link && CanFuseOperations(chain.back(), task_descriptor,
output_ids, *descriptors, chains)) {
chain.push_back(task_descriptor);
fused = true;
break;
}
}
if (!fused) {
chains->push_back({task_descriptor});
}
// Remove operation from original list and start from the beginning.
descriptors->erase(it);
added = true;
break;
} else {
++it;
}
}
} while (!descriptors->empty() && added);
unused_ids->reserve(descriptors->size());
for (const auto& desc : *descriptors) {
unused_ids->push_back(desc.id);
}
}
// Accepts an unsorted list of chains and returns a list sorted in the order of
// GPU task execution.
std::list<FusionSequence> SortChains(
const std::vector<ValueId>& graph_input_ids,
std::list<FusionSequence>* chains) {
std::list<FusionSequence> sorted_chains;
while (!chains->empty()) {
// Collect ready buffers.
std::vector<ValueId> ready_buffer_ids;
ready_buffer_ids.reserve(graph_input_ids.size() + sorted_chains.size());
ready_buffer_ids.insert(ready_buffer_ids.begin(), graph_input_ids.begin(),
graph_input_ids.end());
for (auto& chain : sorted_chains) {
ready_buffer_ids.push_back(chain.back().dst_tensors_ids[0]);
}
for (auto it = chains->begin(); it != chains->end();) {
const FusionSequence& chain = *it;
// If an input is also an output in the same chain, eliminate it because it is
// used only internally within this chain.
std::vector<ValueId> elements_output_buffer_ids;
elements_output_buffer_ids.reserve(chain.size());
for (const auto& element : chain) {
elements_output_buffer_ids.push_back(element.dst_tensors_ids[0]);
}
// Collect all inputs, including those of linked operations.
std::vector<ValueId> elements_input_buffer_ids;
for (const auto& element : chain) {
for (const auto& id : element.src_tensors_ids) {
if (!Contains(elements_output_buffer_ids, id)) {
elements_input_buffer_ids.push_back(id);
}
}
}
if (Contains(ready_buffer_ids, elements_input_buffer_ids)) {
// All input buffers for all elements of this chain are ready.
sorted_chains.push_back(chain);
it = chains->erase(it);
} else {
++it;
}
}
}
return sorted_chains;
}
// If the graph structure contains unused outputs, this can lead to unused
// operations and unused input buffers. It is not an error, but it is worth a
// warning.
std::vector<ValueId> GetUsedInputBufferIds(
const std::list<FusionSequence>& sorted_chains) {
// Match requested outputs with all outputs and intermediate buffers.
std::vector<ValueId> output_and_intermediate_ids;
output_and_intermediate_ids.reserve(sorted_chains.size());
std::set<ValueId> input_and_intermediate_ids;
for (auto it = sorted_chains.begin(); it != sorted_chains.end(); ++it) {
output_and_intermediate_ids.push_back(it->back().dst_tensors_ids[0]);
for (const auto& id : it->front().src_tensors_ids) {
input_and_intermediate_ids.insert(id);
}
}
std::vector<ValueId> input_ids;
for (ValueId id : input_and_intermediate_ids) {
if (!Contains(output_and_intermediate_ids, id)) {
input_ids.push_back(id);
}
}
return input_ids;
}
// If a buffer is requested as a graph output but the graph structure cannot
// produce it (it cannot be deduced), the graph structure is incorrect.
std::vector<ValueId> GetMissingOutputBufferIds(
const std::vector<ValueId>& output_ids,
const std::list<FusionSequence>& sorted_chains) {
// Match requested outputs with all output and intermediate buffers.
std::vector<ValueId> output_and_intermediate_ids;
output_and_intermediate_ids.reserve(sorted_chains.size());
for (auto it = sorted_chains.begin(); it != sorted_chains.end(); ++it) {
output_and_intermediate_ids.push_back(it->back().dst_tensors_ids[0]);
}
std::vector<ValueId> missing_output_ids;
for (ValueId id : output_ids) {
if (!Contains(output_and_intermediate_ids, id)) {
missing_output_ids.push_back(id);
}
}
return missing_output_ids;
}
// The graph may contain leaves whose outputs are not requested. This wastes
// GPU computation.
std::vector<ValueId> DeductOutputBufferIds(
const std::vector<ValueId>& output_ids,
const std::list<FusionSequence>& sorted_chains) {
std::vector<ValueId> extra_output_ids;
// Detect all unused output buffers among all outputs.
for (auto it1 = sorted_chains.begin(); it1 != sorted_chains.end(); ++it1) {
bool found_as_input = false;
for (auto it2 = sorted_chains.begin(); it2 != sorted_chains.end(); ++it2) {
if (it1 != it2) {
std::vector<ValueId> input_ids;
for (const auto& element : *it2) {
for (const auto& id : element.src_tensors_ids) {
input_ids.push_back(id);
}
}
if (Contains(input_ids, it1->back().dst_tensors_ids[0])) {
found_as_input = true;
break;
}
}
}
if (!found_as_input) {
if (!Contains(output_ids, it1->back().dst_tensors_ids[0])) {
extra_output_ids.push_back(it1->back().dst_tensors_ids[0]);
}
}
}
return extra_output_ids;
}
// Deletes all unused task descriptors that have non-requested outputs.
// Note: do not delete a whole chain at once; delete only its last element,
// then the others.
std::vector<int> DeleteUnusedTasks(const std::vector<ValueId>& output_ids,
std::list<FusionSequence>* chains) {
std::vector<int> unused_operations;
for (auto it1 = chains->rbegin(); it1 != chains->rend();) {
// Don't delete if output is requested.
if (Contains(output_ids, it1->back().dst_tensors_ids[0])) {
++it1;
continue;
}
// Don't delete if some operation uses the output.
bool output_used = false;
for (auto it2 = chains->rbegin(); it2 != chains->rend(); ++it2) {
std::vector<ValueId> input_ids;
for (const auto& element : *it2) {
for (const auto& id : element.src_tensors_ids) {
input_ids.push_back(id);
}
}
if (Contains(input_ids, it1->back().dst_tensors_ids[0])) {
output_used = true;
break;
}
}
if (output_used) {
++it1;
continue;
}
// Delete if not used.
unused_operations.push_back(it1->back().id);
it1 = decltype(it1){chains->erase(std::next(it1).base())};
}
return unused_operations;
}
// Removes input proxy operations; a chain that becomes empty corresponds to an
// unused input buffer.
void RemoveInputProxies(std::list<FusionSequence>* chains) {
for (auto it = chains->begin(); it != chains->end();) {
auto& chain = *it;
// Remove input proxy-operations.
if (chain.front().src_tensors_ids.empty()) {
chain.erase(chain.begin());
}
if (chain.empty()) {
// The input proxy operation has been deleted, and the chain is empty because
// its input buffer is unused.
it = chains->erase(it);
} else {
++it;
}
}
}
absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst) {
for (int j = 1; j < src->src_tensors_ids.size(); ++j) {
dst->src_tensors_ids.push_back(src->src_tensors_ids[j]);
@ -430,14 +101,6 @@ absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst) {
dst->description += " linked : " + src->description;
return dst->task->AddTask(src->task.get());
}
absl::Status FuseChain(const FusionSequence& chain, NodeDescriptor* node_desc) {
*node_desc = chain.front();
for (int j = 1; j < chain.size(); ++j) {
RETURN_IF_ERROR(MergeNodes(&chain[j], node_desc));
}
return absl::OkStatus();
}
} // namespace
absl::Status InferenceContext::InitFromGraph(
@ -459,11 +122,8 @@ absl::Status InferenceContext::InitFromGraph(
CompiledModel compiled_model;
RETURN_IF_ERROR(
Compile(graph, gpu_info, create_info.precision, &compiled_model));
CompiledModel optimized_model;
RETURN_IF_ERROR(ValidateOptimizeModel(input_ids_, output_ids_, compiled_model,
&optimized_model));
RETURN_IF_ERROR(CompileModelWithDevice(device, optimized_model, input_ids_,
RETURN_IF_ERROR(Merge(&compiled_model));
RETURN_IF_ERROR(CompileModelWithDevice(device, compiled_model, input_ids_,
output_ids_, create_info.precision));
return absl::OkStatus();
}
@ -500,18 +160,17 @@ absl::Status InferenceContext::Compile(const GraphFloat32& graph,
// associativity and ADD can be linked. In the current approach, the "linking"
// tensor can only be the latest written tensor (in the linear order of
// execution) among the input tensors.
// TODO(b/176397043) sorokin, check failure on segmentation model
// if (IsGenericAdd(node, inputs, outputs)) {
// int latest_written_tensor_index = 0;
// int last_usage = tensor_usages[inputs[0]->id];
// for (int j = 1; j < inputs.size(); ++j) {
// if (tensor_usages[inputs[j]->id] > last_usage) {
// last_usage = tensor_usages[inputs[j]->id];
// latest_written_tensor_index = j;
// }
// }
// std::swap(inputs[0], inputs[latest_written_tensor_index]);
// }
if (IsGenericAdd(node, inputs, outputs)) {
int latest_written_tensor_index = 0;
int last_usage = tensor_usages[inputs[0]->id];
for (int j = 1; j < inputs.size(); ++j) {
if (tensor_usages[inputs[j]->id] > last_usage) {
last_usage = tensor_usages[inputs[j]->id];
latest_written_tensor_index = j;
}
}
std::swap(inputs[0], inputs[latest_written_tensor_index]);
}
DataType data_type = DeduceDataTypeFromPrecision(precision);
TensorDescriptor tensor_descriptor =
TensorDescriptor{data_type, TensorStorageType::BUFFER, Layout::HWC};
@ -557,63 +216,53 @@ absl::Status InferenceContext::Compile(const GraphFloat32& graph,
}
metal_node.description =
node.operation.type + " " + std::to_string(node.id);
metal_node.id = i;
compiled_model->nodes.push_back(std::move(metal_node));
}
}
return absl::OkStatus();
}
absl::Status InferenceContext::ValidateOptimizeModel(
const std::vector<ValueId>& input_buffers,
const std::vector<ValueId>& output_buffers,
const CompiledModel& input_model, CompiledModel* output_model) {
std::list<NodeDescriptor> input;
input.insert(input.end(), input_model.nodes.begin(), input_model.nodes.end());
OptimizationInfo info;
info.operations_count = static_cast<int>(input.size());
// A chain is a sequence of fusable operations. All internal outputs are
// consumed by the next element of the chain. The last element of each chain
// contains outputs that are ready to be used as inputs. If a chain cannot be
// extended with a linkable element, a new chain is created.
std::list<FusionSequence> unsorted_chains;
BuildFusableChains(input_buffers, output_buffers, &input, &unsorted_chains,
&info.unused_operations);
RemoveInputProxies(&unsorted_chains);
std::list<FusionSequence> sorted_chains =
SortChains(input_buffers, &unsorted_chains);
info.extra_output_buffer_ids =
DeductOutputBufferIds(output_buffers, sorted_chains);
info.unused_operations = DeleteUnusedTasks(output_buffers, &sorted_chains);
info.input_buffer_ids = GetUsedInputBufferIds(sorted_chains);
// Find provided input buffers that have not been used.
for (ValueId id : input_buffers) {
if (!Contains(info.input_buffer_ids, id)) {
info.unused_input_buffer_ids.push_back(id);
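// Merges nodes as in the OpenCL inference context: each node with a single
// output is fused with its sole linkable consumer in one linear pass over the
// execution order.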
absl::Status InferenceContext::Merge(CompiledModel* model) {
std::set<ValueId> ready_tensors;
for (const auto& input_id : input_ids_) {
ready_tensors.insert(input_id);
}
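// Walk the nodes in execution order; each node's outputs become ready as the
// pass advances.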
for (int i = 0; i < model->nodes.size(); ++i) {
auto& node = model->nodes[i];
for (const auto& out_id : node.dst_tensors_ids) {
ready_tensors.insert(out_id);
}
if (node.dst_tensors_ids.size() != 1) {
continue;
}
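// Find every consumer of this node's single output and remember which of the
// consumer's input slots it feeds.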
std::vector<int> next_nodes;
int link_index = 0;
for (int j = i + 1; j < model->nodes.size(); ++j) {
for (int k = 0; k < model->nodes[j].src_tensors_ids.size(); ++k) {
if (model->nodes[j].src_tensors_ids[k] == node.dst_tensors_ids[0]) {
next_nodes.push_back(j);
link_index = k;
}
}
}
if (next_nodes.size() != 1 || link_index != 0) {
continue;
}
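// The single consumer must be linkable, produce a single output, and have all
// of its other inputs already computed.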
auto& linkable_node = model->nodes[next_nodes[0]];
if (!linkable_node.task->is_linkable ||
linkable_node.dst_tensors_ids.size() != 1 ||
!IsReady(ready_tensors, linkable_node)) {
continue;
}
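// The producer and consumer must agree on the destination tensor definition.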
const auto& original_dst_def = node.task->definition.dst_tensors[0];
const auto& link_dst_def = linkable_node.task->definition.dst_tensors[0];
if (original_dst_def != link_dst_def) {
continue;
}
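// Merge the consumer into this node, remove it from the model, and revisit
// the merged node in case it can absorb another consumer.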
RETURN_IF_ERROR(MergeNodes(&linkable_node, &node));
model->nodes.erase(model->nodes.begin() + next_nodes[0]);
i -= 1;
}
info.missing_output_buffer_ids =
GetMissingOutputBufferIds(output_buffers, sorted_chains);
info.gpu_tasks_count = static_cast<int>(sorted_chains.size());
if (sorted_chains.empty()) {
const std::string message =
"No valid operations in the graph.\nInput operations count " +
std::to_string(info.operations_count) + "\nUnused operations " +
std::to_string(info.unused_operations.size()) + "\nUnused inputs " +
std::to_string(info.unused_input_buffer_ids.size()) +
"\nMissing output buffers " +
std::to_string(info.missing_output_buffer_ids.size());
return absl::InternalError(message);
}
for (const auto& chain : sorted_chains) {
NodeDescriptor fused_node;
RETURN_IF_ERROR(FuseChain(chain, &fused_node));
output_model->nodes.push_back(std::move(fused_node));
}
output_model->tensor_shapes = input_model.tensor_shapes;
return absl::OkStatus();
}


@ -106,6 +106,7 @@ class InferenceContext {
const std::vector<ValueId>& output_ids,
CalculationsPrecision precision);
absl::Status Merge(CompiledModel* model);
absl::Status AllocateTensors(id<MTLDevice> device);
absl::Status AllocateMemoryForBuffers(id<MTLDevice> device);
void BindTensorsToOperations();