Fixed GetCLNodeTensors.
PiperOrigin-RevId: 267193520
This commit is contained in:
parent
059fe88ded
commit
daea9fbc16
@ -53,22 +53,21 @@ bool IsReady(const std::unordered_set<ValueId>& ready_tensors,
|
||||
std::vector<std::pair<ValueId, TensorDescriptor>> GetCLNodeTensors(
|
||||
const CLNode& node) {
|
||||
std::vector<std::pair<ValueId, TensorDescriptor>> result;
|
||||
for (int i = 0; i < node.operations.size(); ++i) {
|
||||
const OperationDef op_def = node.operations[i]->GetDefinition();
|
||||
const auto& first_range = node.ranges[0];
|
||||
for (int k = first_range.x; k < first_range.y; ++k) {
|
||||
result.push_back({node.inputs[k], op_def.src_tensors[k - first_range.x]});
|
||||
}
|
||||
for (int j = 1; j < node.ranges.size(); ++j) {
|
||||
const auto& range = node.ranges[j];
|
||||
for (int k = range.x; k < range.y; ++k) {
|
||||
result.push_back({node.inputs[k], op_def.src_tensors[k - range.x + 1]});
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < node.outputs.size(); ++j) {
|
||||
result.push_back({node.outputs[j], op_def.dst_tensors[j]});
|
||||
const OperationDef main_def = node.operations[0]->GetDefinition();
|
||||
const auto& first_range = node.ranges[0];
|
||||
for (int k = first_range.x; k < first_range.y; ++k) {
|
||||
result.push_back({node.inputs[k], main_def.src_tensors[k - first_range.x]});
|
||||
}
|
||||
for (int j = 1; j < node.ranges.size(); ++j) {
|
||||
const auto& range = node.ranges[j];
|
||||
const OperationDef op_def = node.operations[j]->GetDefinition();
|
||||
for (int k = range.x; k < range.y; ++k) {
|
||||
result.push_back({node.inputs[k], op_def.src_tensors[k - range.x + 1]});
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < node.outputs.size(); ++j) {
|
||||
result.push_back({node.outputs[j], main_def.dst_tensors[j]});
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user