Convert template classes to simple classes:

* Value<TensorRef<BHWC>> is now Value.
* Graph<TensorRef<BHWC>> -> Model<TensorRef<BHWC>> -> typedef GraphFloat32 is now a single class GraphFloat32.

PiperOrigin-RevId: 306926255
Change-Id: I6730c1a999504c7484aec16ecf4b6aa213607c70
This commit is contained in:
parent
339eab07e8
commit
22350e7ca8
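
In short, the change collapses a template hierarchy that only ever had one instantiation into concrete types. A simplified sketch of the before/after shape (abridged from the model.h diff below, not the literal declarations):

    // Before: templates, instantiated exactly once via a typedef.
    template <typename TensorT>
    struct Value { TensorT tensor; /* ... */ };
    template <typename TensorT>
    class Graph { /* pure virtual interface */ };
    template <typename TensorT>
    class Model : public Graph<TensorT> { /* implementation */ };
    using GraphFloat32 = Model<TensorRef<BHWC>>;

    // After: one concrete struct and one concrete class.
    struct Value { TensorRef<BHWC> tensor; /* ... */ };
    class GraphFloat32 { /* former Graph interface plus Model implementation */ };

Call sites simply drop the template argument (std::vector<Value<TensorRef<BHWC>>*> becomes std::vector<Value*>), and the method bodies move out of the header into the new model.cc.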
@@ -661,9 +661,8 @@ class InferenceBuilderImpl : public InferenceBuilder {
   }

   // Links internal tensors with external user-facing objects.
-  std::vector<TensorTieDef> LinkTensors(
-      const GraphFloat32& graph,
-      const std::vector<Value<TensorRef<BHWC>>*>& values) {
+  std::vector<TensorTieDef> LinkTensors(const GraphFloat32& graph,
+                                        const std::vector<Value*>& values) {
     std::vector<TensorTieDef> links;
     links.reserve(values.size());
     for (const auto& value : values) {
@@ -119,9 +119,8 @@ bool IsBufferBased(const TensorStorageType& type) {

 // Generic add is add that have several runtime inputs and they are not
 // broadcasted, i.e. pointwise add for N tensors where N > 1.
-bool IsGenericAdd(const Node& node,
-                  const std::vector<Value<TensorRef<BHWC>>*>& inputs,
-                  const std::vector<Value<TensorRef<BHWC>>*>& outputs) {
+bool IsGenericAdd(const Node& node, const std::vector<Value*>& inputs,
+                  const std::vector<Value*>& outputs) {
   if (inputs.size() == 1) {
     return false;
   }
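
To illustrate the predicate above (operand shapes invented for the example; only the first check of the function body is visible in this hunk):

    // Two runtime inputs of identical shape -> pointwise, "generic" add:
    //   inputs[0]->tensor.shape == BHWC(1, 32, 32, 16)
    //   inputs[1]->tensor.shape == BHWC(1, 32, 32, 16)
    // IsGenericAdd(node, inputs, outputs) == true
    //
    // A node with a single runtime input (the other operand folded into its
    // attributes) returns false at the inputs.size() == 1 check.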
@@ -30,9 +30,8 @@ namespace cl {

 absl::Status SelectDefault(const CreationContext& creation_context,
                            const OperationDef& op_def, ModelHints hints,
-                           const std::vector<Value<TensorRef<BHWC>>*>& inputs,
-                           const std::vector<Value<TensorRef<BHWC>>*>& outputs,
-                           const Node& node,
+                           const std::vector<Value*>& inputs,
+                           const std::vector<Value*>& outputs, const Node& node,
                            GPUOperationsSubgraph* gpu_subgraph) {
   return absl::UnimplementedError(
       absl::StrCat("No selector for ", node.operation.type));
@@ -31,9 +31,8 @@ namespace cl {

 absl::Status SelectDefault(const CreationContext& creation_context,
                            const OperationDef& op_def, ModelHints hints,
-                           const std::vector<Value<TensorRef<BHWC>>*>& inputs,
-                           const std::vector<Value<TensorRef<BHWC>>*>& outputs,
-                           const Node& node,
+                           const std::vector<Value*>& inputs,
+                           const std::vector<Value*>& outputs, const Node& node,
                            GPUOperationsSubgraph* gpu_subgraph);

 }  // namespace cl
@@ -37,20 +37,17 @@ namespace gpu {
 namespace cl {
 namespace {

-bool IsWidthBroadcastedForSecondInput(
-    const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
+bool IsWidthBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
   return inputs.size() == 2 &&
          inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
          inputs[1]->tensor.shape.w == 1;
 }
-bool IsHeightBroadcastedForSecondInput(
-    const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
+bool IsHeightBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
   return inputs.size() == 2 &&
          inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
          inputs[1]->tensor.shape.h == 1;
 }
-bool IsChannelsBroadcastedForSecondInput(
-    const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
+bool IsChannelsBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
   return inputs.size() == 2 &&
          inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
          inputs[1]->tensor.shape.c == 1;
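
All three predicates above share one pattern: exactly two inputs, a mismatch along one axis, and extent 1 for the second input on that axis. A quick illustration (shapes invented for the example):

    // inputs[0]->tensor.shape == BHWC(1, 8, 16, 4)
    // inputs[1]->tensor.shape == BHWC(1, 8, 1, 4)
    IsWidthBroadcastedForSecondInput(inputs);     // true:  w is 16 vs. 1
    IsHeightBroadcastedForSecondInput(inputs);    // false: h is 8 on both
    IsChannelsBroadcastedForSecondInput(inputs);  // false: c is 4 on both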
@@ -146,11 +143,12 @@ absl::Status WinogradFromNode(const CreationContext& creation_context,

 }  // namespace

-absl::Status GPUOperationFromNode(
-    const CreationContext& creation_context, const OperationDef& op_def,
-    ModelHints hints, const std::vector<Value<TensorRef<BHWC>>*>& inputs,
-    const std::vector<Value<TensorRef<BHWC>>*>& outputs, const Node& node,
-    GPUOperationsSubgraph* gpu_subgraph) {
+absl::Status GPUOperationFromNode(const CreationContext& creation_context,
+                                  const OperationDef& op_def, ModelHints hints,
+                                  const std::vector<Value*>& inputs,
+                                  const std::vector<Value*>& outputs,
+                                  const Node& node,
+                                  GPUOperationsSubgraph* gpu_subgraph) {
   std::unique_ptr<GPUOperation>* gpu_op =
       InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
   auto op_type = OperationTypeFromString(node.operation.type);
@@ -29,11 +29,12 @@ namespace tflite {
 namespace gpu {
 namespace cl {

-absl::Status GPUOperationFromNode(
-    const CreationContext& creation_context, const OperationDef& op_def,
-    ModelHints hints, const std::vector<Value<TensorRef<BHWC>>*>& inputs,
-    const std::vector<Value<TensorRef<BHWC>>*>& outputs, const Node& node,
-    GPUOperationsSubgraph* gpu_subgraph);
+absl::Status GPUOperationFromNode(const CreationContext& creation_context,
+                                  const OperationDef& op_def, ModelHints hints,
+                                  const std::vector<Value*>& inputs,
+                                  const std::vector<Value*>& outputs,
+                                  const Node& node,
+                                  GPUOperationsSubgraph* gpu_subgraph);

 }  // namespace cl
 }  // namespace gpu
@@ -26,8 +26,7 @@ namespace gpu {
 namespace cl {

 std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
-    const std::vector<Value<TensorRef<BHWC>>*>& inputs,
-    const std::vector<Value<TensorRef<BHWC>>*>& outputs,
+    const std::vector<Value*>& inputs, const std::vector<Value*>& outputs,
     GPUOperationsSubgraph* gpu_subgraph) {
   gpu_subgraph->operations.clear();
   gpu_subgraph->new_tensors.clear();
@@ -43,8 +43,7 @@ struct GPUOperationsSubgraph {
 };

 std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
-    const std::vector<Value<TensorRef<BHWC>>*>& inputs,
-    const std::vector<Value<TensorRef<BHWC>>*>& outputs,
+    const std::vector<Value*>& inputs, const std::vector<Value*>& outputs,
     GPUOperationsSubgraph* gpu_subgraph);

 }  // namespace cl
@@ -78,6 +78,7 @@ cc_library(

 cc_library(
     name = "model",
+    srcs = ["model.cc"],
     hdrs = ["model.h"],
     deps = [
         ":data_type",
tensorflow/lite/delegates/gpu/common/model.cc (451 lines, new file)
@@ -0,0 +1,451 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/model.h"

#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

std::vector<Node*> GraphFloat32::nodes() const {
  return FilterNodes([](const NodeDef&) { return true; });
}

std::vector<Value*> GraphFloat32::values() const {
  return FilterValues([](const ValueDef&) { return true; });
}

std::vector<Value*> GraphFloat32::inputs() const {
  return FilterValues([](const ValueDef& v) { return v.producer == nullptr; });
}

std::vector<Value*> GraphFloat32::outputs() const {
  return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
}

std::vector<Value*> GraphFloat32::FindInputs(NodeId id) const {
  if (id >= nodes_.size()) {
    return {};
  }
  return nodes_.at(id).inputs;
}

std::vector<Value*> GraphFloat32::FindOutputs(NodeId id) const {
  if (id >= nodes_.size()) {
    return {};
  }
  return nodes_.at(id).outputs;
}

bool GraphFloat32::IsGraphInput(ValueId id) const {
  if (id >= values_.size()) {
    return false;
  }
  return values_[id].producer == nullptr;
}

bool GraphFloat32::IsGraphOutput(ValueId id) const {
  if (id >= values_.size()) {
    return false;
  }
  return values_[id].consumers.empty();
}

Node* GraphFloat32::FindProducer(ValueId id) const {
  if (id >= values_.size()) {
    return nullptr;
  }
  return values_[id].producer;
}

std::vector<Node*> GraphFloat32::FindConsumers(ValueId id) const {
  if (id >= values_.size()) {
    return {};
  }
  return values_[id].consumers;
}

Node* GraphFloat32::GetNode(NodeId id) const {
  if (id >= nodes_.size()) {
    return {};
  }
  return nodes_.at(id).node.get();
}

Value* GraphFloat32::GetValue(ValueId id) const {
  if (id >= values_.size()) {
    return nullptr;
  }
  return values_[id].value.get();
}

Node* GraphFloat32::NewNode() {
  const NodeId new_id = nodes_.size();
  NodeDef def;
  def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
  Node* node = def.node.get();
  nodes_[new_id] = std::move(def);
  execution_plan_.push_back(new_id);
  return node;
}

absl::Status GraphFloat32::InsertNodeAfter(NodeId id, Node** new_node) {
  if (id >= nodes_.size()) {
    return absl::OutOfRangeError("NodeId is out of range");
  }
  int idx = 0;
  while (idx < execution_plan_.size()) {
    if (execution_plan_[idx] == id) break;
    ++idx;
  }
  if (idx == execution_plan_.size()) {
    return absl::OutOfRangeError("NodeId not in execution plan");
  }

  const NodeId new_id = nodes_.size();
  NodeDef def;
  def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
  *new_node = def.node.get();
  nodes_[new_id] = std::move(def);
  execution_plan_.insert(execution_plan_.begin() + idx + 1, new_id);
  return absl::OkStatus();
}

Value* GraphFloat32::NewValue() {
  ValueDef def;
  def.value =
      absl::make_unique<Value>(Value{static_cast<ValueId>(values_.size()), {}});
  Value* value = def.value.get();
  values_.push_back(std::move(def));
  return value;
}

absl::Status GraphFloat32::SetProducer(NodeId producer, ValueId value) {
  ValueDef* v;
  RETURN_IF_ERROR(LookupValue(value, &v));
  Value* value_ptr = v->value.get();
  NodeDef* n;
  RETURN_IF_ERROR(LookupNode(producer, &n));
  Node* node_ptr = n->node.get();

  // check if this value has the same producer already
  if (node_ptr == v->producer) {
    return absl::AlreadyExistsError(absl::StrCat(
        "Node ", producer, " is already a producer of the value ", value));
  }

  // Check if the node is a consumer of this value.
  if (IsInput(producer, value)) {
    return absl::InvalidArgumentError("Node is a consumer of the value");
  }

  if (v->producer != nullptr) {
    // value is no longer produced by it's previous producer.
    Erase(&nodes_[v->producer->id].outputs, value_ptr);
  }
  v->producer = node_ptr;
  n->outputs.push_back(value_ptr);
  return absl::OkStatus();
}

absl::Status GraphFloat32::RemoveProducer(ValueId value) {
  ValueDef* v;
  RETURN_IF_ERROR(LookupValue(value, &v));
  Value* value_ptr = v->value.get();
  if (v->producer == nullptr) {
    return absl::InvalidArgumentError("Value does not have a producer");
  }
  Erase(&nodes_[v->producer->id].outputs, value_ptr);
  v->producer = nullptr;
  return absl::OkStatus();
}

absl::Status GraphFloat32::AddConsumer(NodeId consumer, ValueId value) {
  ValueDef* v;
  RETURN_IF_ERROR(LookupValue(value, &v));
  Value* value_ptr = v->value.get();
  NodeDef* n;
  RETURN_IF_ERROR(LookupNode(consumer, &n));
  Node* node_ptr = n->node.get();

  // check if this value has the same producer already
  if (node_ptr == v->producer) {
    return absl::InvalidArgumentError("Node is a producer of the value");
  }

  // check if this value has the same consumer already
  if (IsInput(consumer, value)) {
    return absl::AlreadyExistsError(absl::StrCat(
        "Node ", consumer, " is already a consumer of the value ", value));
  }

  n->inputs.push_back(value_ptr);
  v->consumers.push_back(node_ptr);
  return absl::OkStatus();
}

// Replace input value for given node.
absl::Status GraphFloat32::ReplaceInput(NodeId node, ValueId old_value,
                                        ValueId new_value) {
  ValueDef* v_old;
  RETURN_IF_ERROR(LookupValue(old_value, &v_old));
  Value* value_old_ptr = v_old->value.get();
  ValueDef* v_new;
  RETURN_IF_ERROR(LookupValue(new_value, &v_new));
  Value* value_new_ptr = v_new->value.get();
  NodeDef* n;
  RETURN_IF_ERROR(LookupNode(node, &n));
  Node* node_ptr = n->node.get();

  // Check if the node is a consumer of old_value.
  if (!IsInput(node, old_value)) {
    return absl::InvalidArgumentError("old_value must be input of node.");
  }

  // Check if the node is not a consumer of new_value.
  if (IsInput(node, new_value)) {
    return absl::InvalidArgumentError("new_value can not be input of node.");
  }

  // Check if this value has the same producer already
  if (node_ptr == v_new->producer) {
    return absl::InvalidArgumentError("new_value can not be output of node.");
  }

  for (int i = 0; i < n->inputs.size(); ++i) {
    if (n->inputs[i] == value_old_ptr) {
      n->inputs[i] = value_new_ptr;
      break;
    }
  }
  v_new->consumers.push_back(node_ptr);
  Erase(&v_old->consumers, node_ptr);
  return absl::OkStatus();
}

absl::Status GraphFloat32::RemoveConsumer(NodeId consumer, ValueId value) {
  ValueDef* v;
  RETURN_IF_ERROR(LookupValue(value, &v));
  Value* value_ptr = v->value.get();
  NodeDef* n;
  RETURN_IF_ERROR(LookupNode(consumer, &n));
  Node* node_ptr = n->node.get();
  if (!IsInput(consumer, value)) {
    return absl::InvalidArgumentError("Node is not a consumer of the value");
  }
  Erase(&n->inputs, value_ptr);
  Erase(&v->consumers, node_ptr);
  return absl::OkStatus();
}

absl::Status GraphFloat32::DeleteNode(NodeId id) {
  NodeDef* n;
  RETURN_IF_ERROR(LookupNode(id, &n));
  Node* node_ptr = n->node.get();
  for (auto value : n->inputs) {
    Erase(&values_[value->id].consumers, node_ptr);
  }
  for (auto value : n->outputs) {
    values_[value->id].producer = nullptr;
  }
  n->inputs.clear();
  n->outputs.clear();
  n->node.reset();
  return absl::OkStatus();
}

absl::Status GraphFloat32::DeleteValue(ValueId id) {
  ValueDef* v;
  RETURN_IF_ERROR(LookupValue(id, &v));
  Value* value_ptr = v->value.get();
  if (v->producer != nullptr) {
    Erase(&nodes_[v->producer->id].outputs, value_ptr);
  }
  if (!v->consumers.empty()) {
    for (auto node : v->consumers) {
      Erase(&nodes_[node->id].inputs, value_ptr);
    }
  }
  v->producer = nullptr;
  v->consumers.clear();
  v->value.reset();
  return absl::OkStatus();
}

absl::Status GraphFloat32::MakeExactCopy(GraphFloat32* model) const {
  model->nodes_.clear();
  model->execution_plan_.clear();
  model->values_.clear();
  for (auto& value_def : values_) {
    model->values_.push_back({});
    if (value_def.value) {
      model->values_.back().value = absl::make_unique<Value>(*value_def.value);
    }
  }
  // Add all nodes first.
  for (auto node_id : execution_plan_) {
    model->execution_plan_.push_back(node_id);
    model->nodes_[node_id] = {};
    auto& node_def = nodes_.at(node_id);
    if (node_def.node) {
      model->nodes_[node_id].node = absl::make_unique<Node>(*node_def.node);
    }
  }
  // Wire up dependencies between nodes.
  for (auto node_id : execution_plan_) {
    auto& node_def = nodes_.at(node_id);
    if (node_def.node) {
      for (auto output : node_def.outputs) {
        RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id));
      }
      for (auto input : node_def.inputs) {
        RETURN_IF_ERROR(model->AddConsumer(node_def.node->id, input->id));
      }
    }
  }
  return absl::OkStatus();
}

bool GraphFloat32::IsInput(NodeId node, ValueId value) {
  if (node >= nodes_.size() || value >= values_.size()) {
    return false;
  }
  const NodeDef& n = nodes_[node];
  const ValueDef& v = values_[value];
  if (!n.node || !v.value) {
    return false;
  }
  return std::find(n.inputs.begin(), n.inputs.end(), v.value.get()) !=
         n.inputs.end();
}

absl::Status GraphFloat32::LookupNode(NodeId id, NodeDef** node_def) {
  if (id >= nodes_.size()) {
    return absl::OutOfRangeError("NodeId is out of range");
  }
  auto& n = nodes_[id];
  if (!n.node) {
    return absl::OutOfRangeError("Node is already deleted");
  }
  *node_def = &n;
  return absl::OkStatus();
}

absl::Status GraphFloat32::LookupValue(ValueId id, ValueDef** value_def) {
  if (id >= values_.size()) {
    return absl::OutOfRangeError("ValueId is out of range");
  }
  auto& v = values_[id];
  if (!v.value) {
    return absl::OutOfRangeError("Value is already deleted");
  }
  *value_def = &v;
  return absl::OkStatus();
}

absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove,
                                 const Node* to_keep) {
  // Make sure all outputs from to_remove are consumed by to_keep.
  for (auto output : graph->FindOutputs(to_remove->id)) {
    auto consumers = graph->FindConsumers(output->id);
    if (consumers.size() > 1 ||
        (consumers.size() == 1 && consumers[0] != to_keep)) {
      return absl::InvalidArgumentError(
          "Output from to_remove node has other consumers");
    }
  }

  // Update all references
  for (auto input : graph->FindInputs(to_remove->id)) {
    RETURN_IF_ERROR(graph->AddConsumer(to_keep->id, input->id));
  }
  for (auto output : graph->FindOutputs(to_remove->id)) {
    RETURN_IF_ERROR(graph->DeleteValue(output->id));
  }
  return graph->DeleteNode(to_remove->id);
}

absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove,
                                 const Node* to_keep) {
  // Make sure all inputs to to_remove are produced by to_keep.
  for (auto input : graph->FindInputs(to_remove->id)) {
    Node* producer = graph->FindProducer(input->id);
    if (producer->id != to_keep->id) {
      return absl::InvalidArgumentError("To_remove node has other inputs");
    }
  }

  for (auto input : graph->FindInputs(to_remove->id)) {
    RETURN_IF_ERROR(graph->DeleteValue(input->id));
  }
  for (auto output : graph->FindOutputs(to_remove->id)) {
    RETURN_IF_ERROR(graph->SetProducer(to_keep->id, output->id));
  }
  return graph->DeleteNode(to_remove->id);
}

absl::Status RemoveOneInputOneOutputNode(GraphFloat32* graph,
                                         const Node* to_remove) {
  auto inputs = graph->FindInputs(to_remove->id);
  auto outputs = graph->FindOutputs(to_remove->id);
  if (inputs.size() != 1 || outputs.size() != 1) {
    return absl::InvalidArgumentError(
        "To_remove node must have 1 input and 1 output");
  }
  auto input_id = inputs[0]->id;
  auto output_id = outputs[0]->id;
  Node* producer = graph->FindProducer(input_id);
  auto consumers = graph->FindConsumers(output_id);
  RETURN_IF_ERROR(graph->DeleteNode(to_remove->id));
  for (auto& consumer : consumers) {
    RETURN_IF_ERROR(graph->ReplaceInput(consumer->id, output_id, input_id));
  }
  RETURN_IF_ERROR(graph->DeleteValue(output_id));
  if (!producer && consumers.empty()) {
    RETURN_IF_ERROR(graph->DeleteValue(input_id));
  }
  return absl::OkStatus();
}

absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
                       Value** output) {
  auto link = graph->NewValue();
  RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id));
  *output = link;
  return absl::OkStatus();
}

absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
                             const Node* to_node, Value** output) {
  Value* link;
  RETURN_IF_ERROR(AddOutput(graph, from_node, &link));
  RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id));
  *output = link;
  return absl::OkStatus();
}

bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
  const int32_t b = model.values()[0]->tensor.shape.b;
  for (auto value : model.values()) {
    if (value->tensor.shape.b != b) {
      return false;
    }
  }
  return true;
}

}  // namespace gpu
}  // namespace tflite
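
Taken together, model.cc gives GraphFloat32 a compact construction API. A minimal usage sketch, mirroring the updated tests further down in this commit:

    GraphFloat32 graph;
    Node* node = graph.NewNode();
    Value* graph_input = graph.NewValue();
    Value* graph_output = graph.NewValue();
    // A value without a producer is a graph input; one without consumers is a
    // graph output, so this wires up: graph_input -> node -> graph_output.
    ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
    ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());

Also worth noting is the erase-rather-than-delete design: DeleteNode and DeleteValue reset the stored unique_ptr but keep the index slot, so ids stay stable and a later LookupNode/LookupValue on a deleted id fails with OutOfRangeError instead of silently aliasing another object.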
@@ -50,77 +50,70 @@ struct QuantizationParams {
 };

 // Connects tensor's producer and operation that depends on this tensor.
-template <typename TensorT>
 struct Value {
-  using TensorType = TensorT;
-
   const ValueId id;
-
-  TensorType tensor;
-
+  TensorRef<BHWC> tensor;
   absl::optional<QuantizationParams> quant_params;
 };

 struct Operation {
   std::string type;

   absl::any attributes;
 };

 struct Node {
   const NodeId id;

   Operation operation;
 };

-// Graph is DAG that consists of nodes and values. Each value may have a single
+// A DAG that consists of nodes and values. Each value may have a single
 // producer node and multiple consumer nodes. Therefore, each node may have
 // multiple input and output values.
 //
 // Value that does not have a producer is a graph's input. Value that does not
 // have a consumer is a graph's output.
 //
-// Interface provides methods for graph introspection and manipulation. Abstract
-// interface makes allows subgraphs representation to ensure safe manipulations.
-template <typename TensorT>
-class Graph {
+// It keeps values and nodes referenced by their index in a vector. Therefore,
+// nodes and values are never deleted, but rather erased, where corresponding
+// index remains.
+//
+// It is possible to re-use removed indices, but it is not implemented yet.
+class GraphFloat32 {
  public:
-  virtual ~Graph() = default;
-
   // @return a collection of nodes in this graph.
-  virtual std::vector<Node*> nodes() const = 0;
+  std::vector<Node*> nodes() const;

   // @return a collection of values in this graph.
-  virtual std::vector<Value<TensorT>*> values() const = 0;
+  std::vector<Value*> values() const;

   // @return graph inputs, that are values without producers.
-  virtual std::vector<Value<TensorT>*> inputs() const = 0;
+  std::vector<Value*> inputs() const;

   // @return graph outputs, that are values without consumers.
-  virtual std::vector<Value<TensorT>*> outputs() const = 0;
+  std::vector<Value*> outputs() const;

   // @return inputs into the given node. Returns empty vector for deleted node.
-  virtual std::vector<Value<TensorT>*> FindInputs(NodeId id) const = 0;
+  std::vector<Value*> FindInputs(NodeId id) const;

   // @return outputs from the given node. Returns empty vector for deleted node.
-  virtual std::vector<Value<TensorT>*> FindOutputs(NodeId id) const = 0;
+  std::vector<Value*> FindOutputs(NodeId id) const;

-  virtual bool IsGraphInput(ValueId id) const = 0;
+  bool IsGraphInput(ValueId id) const;

-  virtual bool IsGraphOutput(ValueId id) const = 0;
+  bool IsGraphOutput(ValueId id) const;

   // @return producer of the given value. Returns nullptr for deleted value.
-  virtual Node* FindProducer(ValueId id) const = 0;
+  Node* FindProducer(ValueId id) const;

   // @return consumers of the given value. Returns empty vector for deleted
   // value.
-  virtual std::vector<Node*> FindConsumers(ValueId id) const = 0;
+  std::vector<Node*> FindConsumers(ValueId id) const;

   // @return a node or nullptr if node with the given id is not present.
-  virtual Node* GetNode(NodeId id) const = 0;
+  Node* GetNode(NodeId id) const;

   // @return a value or nullptr if value with the given id is not present.
-  virtual Value<TensorT>* GetValue(ValueId id) const = 0;
+  Value* GetValue(ValueId id) const;

   //////////////////////////////////////////////////////////////////////////////
   // Graph manipulation functions are below
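
The practical effect on call sites is that the template argument disappears while the method set stays the same, e.g.:

    // Before this commit:
    std::vector<Value<TensorRef<BHWC>>*> inputs = graph.inputs();
    // After:
    std::vector<Value*> inputs = graph.inputs();

Since every former pure-virtual member below becomes a plain member function of GraphFloat32, the abstract-base indirection (and its per-call virtual dispatch) goes away as well.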
@@ -129,386 +122,61 @@ class Graph {
   // @return new node created in this graph
   // NOTE: nodes should be created in the topological order, e.g. node A that
   // depends on a value from node B should be created after node B.
-  virtual Node* NewNode() = 0;
+  Node* NewNode();
+
+  // Insert Node after another in the execution plan.
+  absl::Status InsertNodeAfter(NodeId id, Node** new_node);

   // @return new value created in this graph
-  virtual Value<TensorT>* NewValue() = 0;
+  Value* NewValue();

   // Sets a producer for the given value. There could be a single producer
   // for a value. If a value had another producer, it will reassign producer
   // appropriately. If a value didn't have a producer, it will be removed
   // from a graph's input.
-  virtual absl::Status SetProducer(NodeId producer, ValueId value) = 0;
+  absl::Status SetProducer(NodeId producer, ValueId value);

   // Removes a producer for the given value. Value becomes producer-less and
   // therefore becomes graph's input.
-  virtual absl::Status RemoveProducer(ValueId value) = 0;
+  absl::Status RemoveProducer(ValueId value);

   // Sets a consumer for the given value. There could be multiple consumers
   // for a value.
-  virtual absl::Status AddConsumer(NodeId consumer, ValueId value) = 0;
+  absl::Status AddConsumer(NodeId consumer, ValueId value);

   // Replace input value for given node.
-  virtual absl::Status ReplaceInput(NodeId node, ValueId old_value,
-                                    ValueId new_value) = 0;
+  absl::Status ReplaceInput(NodeId node, ValueId old_value, ValueId new_value);

   // Removes a consumer for the given value. If value does not have any
   // consumers it becomes graph's output.
-  virtual absl::Status RemoveConsumer(NodeId consumer, ValueId value) = 0;
+  absl::Status RemoveConsumer(NodeId consumer, ValueId value);

   // Removes node from this graph. For all input values this node will be
   // removed from consumers and for all output values a producer will be
   // removed.
-  virtual absl::Status DeleteNode(NodeId id) = 0;
+  absl::Status DeleteNode(NodeId id);

   // Removes value from this graph. It will be removed from inputs for all
   // dependent nodes. A node that was a producer of this value will loose its
   // output.
-  virtual absl::Status DeleteValue(ValueId id) = 0;
-};
+  absl::Status DeleteValue(ValueId id);

-// Implementation of a Graph interface. It keeps values and nodes referenced by
-// their index in a vector. Therefore, nodes and values are never deleted, but
-// rather erased, where corresponding index remains.
-//
-// It is possible to re-use removed indices, but it is not implemented yet.
-template <typename TensorT>
-class Model : public Graph<TensorT> {
- public:
-  const std::string& name() const { return name_; }
-
-  void set_name(std::string name) { name_ = std::move(name); }
-
-  std::vector<Value<TensorT>*> values() const final {
-    return FilterValues([](const ValueDef&) { return true; });
-  }
-
-  // Returns nodes in the execution order.
-  std::vector<Node*> nodes() const final {
-    return FilterNodes([](const NodeDef&) { return true; });
-  }
-
-  std::vector<Value<TensorT>*> inputs() const final {
-    return FilterValues(
-        [](const ValueDef& v) { return v.producer == nullptr; });
-  }
-
-  std::vector<Value<TensorT>*> outputs() const final {
-    return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
-  }
-
-  bool IsGraphInput(ValueId id) const final {
-    if (id >= values_.size()) {
-      return false;
-    }
-    return values_[id].producer == nullptr;
-  }
-
-  bool IsGraphOutput(ValueId id) const final {
-    if (id >= values_.size()) {
-      return false;
-    }
-    return values_[id].consumers.empty();
-  }
-
-  Node* GetNode(NodeId id) const final {
-    if (id >= nodes_.size()) {
-      return {};
-    }
-    return nodes_.at(id).node.get();
-  }
-
-  Value<TensorT>* GetValue(ValueId id) const final {
-    if (id >= values_.size()) {
-      return nullptr;
-    }
-    return values_[id].value.get();
-  }
-
-  // Append Node to the end of the execution plan.
-  Node* NewNode() final {
-    const NodeId new_id = nodes_.size();
-    NodeDef def;
-    def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
-    Node* node = def.node.get();
-    nodes_[new_id] = std::move(def);
-    execution_plan_.push_back(new_id);
-    return node;
-  }
-
-  // Insert Node after another in the execution plan.
-  absl::Status InsertNodeAfter(NodeId id, Node** new_node) {
-    if (id >= nodes_.size()) {
-      return absl::OutOfRangeError("NodeId is out of range");
-    }
-    int idx = 0;
-    while (idx < execution_plan_.size()) {
-      if (execution_plan_[idx] == id) break;
-      ++idx;
-    }
-    if (idx == execution_plan_.size()) {
-      return absl::OutOfRangeError("NodeId not in execution plan");
-    }
-
-    const NodeId new_id = nodes_.size();
-    NodeDef def;
-    def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
-    *new_node = def.node.get();
-    nodes_[new_id] = std::move(def);
-    execution_plan_.insert(execution_plan_.begin() + idx + 1, new_id);
-    return absl::OkStatus();
-  }
-
-  Value<TensorT>* NewValue() final {
-    ValueDef def;
-    def.value = absl::make_unique<Value<TensorT>>(
-        Value<TensorT>{static_cast<ValueId>(values_.size()), {}});
-    Value<TensorT>* value = def.value.get();
-    values_.push_back(std::move(def));
-    return value;
-  }
-
-  std::vector<Value<TensorT>*> FindInputs(NodeId id) const final {
-    if (id >= nodes_.size()) {
-      return {};
-    }
-    return nodes_.at(id).inputs;
-  }
-
-  std::vector<Value<TensorT>*> FindOutputs(NodeId id) const final {
-    if (id >= nodes_.size()) {
-      return {};
-    }
-    return nodes_.at(id).outputs;
-  }
-
-  Node* FindProducer(ValueId id) const final {
-    if (id >= values_.size()) {
-      return nullptr;
-    }
-    return values_[id].producer;
-  }
-
-  std::vector<Node*> FindConsumers(ValueId id) const final {
-    if (id >= values_.size()) {
-      return {};
-    }
-    return values_[id].consumers;
-  }
-
-  absl::Status SetProducer(NodeId producer, ValueId value) final {
-    ValueDef* v;
-    RETURN_IF_ERROR(LookupValue(value, &v));
-    Value<TensorT>* value_ptr = v->value.get();
-    NodeDef* n;
-    RETURN_IF_ERROR(LookupNode(producer, &n));
-    Node* node_ptr = n->node.get();
-
-    // check if this value has the same producer already
-    if (node_ptr == v->producer) {
-      return absl::AlreadyExistsError(absl::StrCat(
-          "Node ", producer, " is already a producer of the value ", value));
-    }
-
-    // Check if the node is a consumer of this value.
-    if (IsInput(producer, value)) {
-      return absl::InvalidArgumentError("Node is a consumer of the value");
-    }
-    // TODO(akulik): detect circular dependency?
-
-    if (v->producer != nullptr) {
-      // value is no longer produced by it's previous producer.
-      Erase(&nodes_[v->producer->id].outputs, value_ptr);
-    }
-    v->producer = node_ptr;
-    n->outputs.push_back(value_ptr);
-    return absl::OkStatus();
-  }
-
-  absl::Status RemoveProducer(ValueId value) final {
-    ValueDef* v;
-    RETURN_IF_ERROR(LookupValue(value, &v));
-    Value<TensorT>* value_ptr = v->value.get();
-    if (v->producer == nullptr) {
-      return absl::InvalidArgumentError("Value does not have a producer");
-    }
-    Erase(&nodes_[v->producer->id].outputs, value_ptr);
-    v->producer = nullptr;
-    return absl::OkStatus();
-  }
-
-  absl::Status ReplaceInput(NodeId node, ValueId old_value,
-                            ValueId new_value) final {
-    ValueDef* v_old;
-    RETURN_IF_ERROR(LookupValue(old_value, &v_old));
-    Value<TensorT>* value_old_ptr = v_old->value.get();
-    ValueDef* v_new;
-    RETURN_IF_ERROR(LookupValue(new_value, &v_new));
-    Value<TensorT>* value_new_ptr = v_new->value.get();
-    NodeDef* n;
-    RETURN_IF_ERROR(LookupNode(node, &n));
-    Node* node_ptr = n->node.get();
-
-    // Check if the node is a consumer of old_value.
-    if (!IsInput(node, old_value)) {
-      return absl::InvalidArgumentError("old_value must be input of node.");
-    }
-
-    // Check if the node is not a consumer of new_value.
-    if (IsInput(node, new_value)) {
-      return absl::InvalidArgumentError("new_value can not be input of node.");
-    }
-
-    // Check if this value has the same producer already
-    if (node_ptr == v_new->producer) {
-      return absl::InvalidArgumentError("new_value can not be output of node.");
-    }
-
-    for (int i = 0; i < n->inputs.size(); ++i) {
-      if (n->inputs[i] == value_old_ptr) {
-        n->inputs[i] = value_new_ptr;
-        break;
-      }
-    }
-    v_new->consumers.push_back(node_ptr);
-    Erase(&v_old->consumers, node_ptr);
-    return absl::OkStatus();
-  }
-
-  absl::Status AddConsumer(NodeId consumer, ValueId value) final {
-    ValueDef* v;
-    RETURN_IF_ERROR(LookupValue(value, &v));
-    Value<TensorT>* value_ptr = v->value.get();
-    NodeDef* n;
-    RETURN_IF_ERROR(LookupNode(consumer, &n));
-    Node* node_ptr = n->node.get();
-
-    // check if this value has the same producer already
-    if (node_ptr == v->producer) {
-      return absl::InvalidArgumentError("Node is a producer of the value");
-    }
-
-    // check if this value has the same consumer already
-    if (IsInput(consumer, value)) {
-      return absl::AlreadyExistsError(absl::StrCat(
-          "Node ", consumer, " is already a consumer of the value ", value));
-    }
-
-    n->inputs.push_back(value_ptr);
-    v->consumers.push_back(node_ptr);
-    return absl::OkStatus();
-  }
-
-  absl::Status RemoveConsumer(NodeId consumer, ValueId value) final {
-    ValueDef* v;
-    RETURN_IF_ERROR(LookupValue(value, &v));
-    Value<TensorT>* value_ptr = v->value.get();
-    NodeDef* n;
-    RETURN_IF_ERROR(LookupNode(consumer, &n));
-    Node* node_ptr = n->node.get();
-    if (!IsInput(consumer, value)) {
-      return absl::InvalidArgumentError("Node is not a consumer of the value");
-    }
-    Erase(&n->inputs, value_ptr);
-    Erase(&v->consumers, node_ptr);
-    return absl::OkStatus();
-  }
-
-  absl::Status DeleteNode(NodeId id) final {
-    NodeDef* n;
-    RETURN_IF_ERROR(LookupNode(id, &n));
-    Node* node_ptr = n->node.get();
-    for (auto value : n->inputs) {
-      Erase(&values_[value->id].consumers, node_ptr);
-    }
-    for (auto value : n->outputs) {
-      values_[value->id].producer = nullptr;
-    }
-    n->inputs.clear();
-    n->outputs.clear();
-    n->node.reset();
-    return absl::OkStatus();
-  }
-
-  absl::Status DeleteValue(ValueId id) final {
-    ValueDef* v;
-    RETURN_IF_ERROR(LookupValue(id, &v));
-    Value<TensorT>* value_ptr = v->value.get();
-    if (v->producer != nullptr) {
-      Erase(&nodes_[v->producer->id].outputs, value_ptr);
-    }
-    if (!v->consumers.empty()) {
-      for (auto node : v->consumers) {
-        Erase(&nodes_[node->id].inputs, value_ptr);
-      }
-    }
-    v->producer = nullptr;
-    v->consumers.clear();
-    v->value.reset();
-    return absl::OkStatus();
-  }
-
-  absl::Status MakeExactCopy(Model<TensorT>* model) const {
-    model->nodes_.clear();
-    model->execution_plan_.clear();
-    model->values_.clear();
-    model->name_ = name_;
-    for (auto& value_def : values_) {
-      model->values_.push_back({});
-      if (value_def.value) {
-        model->values_.back().value =
-            absl::make_unique<Value<TensorT>>(*value_def.value);
-      }
-    }
-    // Add all nodes first.
-    for (auto node_id : execution_plan_) {
-      model->execution_plan_.push_back(node_id);
-      model->nodes_[node_id] = {};
-      auto& node_def = nodes_.at(node_id);
-      if (node_def.node) {
-        model->nodes_[node_id].node = absl::make_unique<Node>(*node_def.node);
-      }
-    }
-    // Wire up dependencies between nodes.
-    for (auto node_id : execution_plan_) {
-      auto& node_def = nodes_.at(node_id);
-      if (node_def.node) {
-        for (auto output : node_def.outputs) {
-          RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id));
-        }
-        for (auto input : node_def.inputs) {
-          RETURN_IF_ERROR(model->AddConsumer(node_def.node->id, input->id));
-        }
-      }
-    }
-    return absl::OkStatus();
-  }
+  absl::Status MakeExactCopy(GraphFloat32* model) const;

  private:
   struct NodeDef {
-    std::vector<Value<TensorT>*> inputs;
-    std::vector<Value<TensorT>*> outputs;
+    std::vector<Value*> inputs;
+    std::vector<Value*> outputs;
     std::unique_ptr<Node> node;
   };

   struct ValueDef {
     Node* producer = nullptr;
     std::vector<Node*> consumers;
-    std::unique_ptr<Value<TensorT>> value;
+    std::unique_ptr<Value> value;
   };

-  bool IsInput(NodeId node, ValueId value) {
-    if (node >= nodes_.size() || value >= values_.size()) {
-      return false;
-    }
-    const NodeDef& n = nodes_[node];
-    const ValueDef& v = values_[value];
-    if (!n.node || !v.value) {
-      return false;
-    }
-    return std::find(n.inputs.begin(), n.inputs.end(), v.value.get()) !=
-           n.inputs.end();
-  }
+  bool IsInput(NodeId node, ValueId value);

   template <typename T>
   static void Erase(std::vector<T>* values, T value) {
@@ -516,34 +184,14 @@ class Model : public Graph<TensorT> {
   }

   // @return non-nullptr NodeDef that has valid Node or an error
-  absl::Status LookupNode(NodeId id, NodeDef** node_def) {
-    if (id >= nodes_.size()) {
-      return absl::OutOfRangeError("NodeId is out of range");
-    }
-    auto& n = nodes_[id];
-    if (!n.node) {
-      return absl::OutOfRangeError("Node is already deleted");
-    }
-    *node_def = &n;
-    return absl::OkStatus();
-  }
+  absl::Status LookupNode(NodeId id, NodeDef** node_def);

   // @return non-nullptr ValueDef that has valid Value or an error
-  absl::Status LookupValue(ValueId id, ValueDef** value_def) {
-    if (id >= values_.size()) {
-      return absl::OutOfRangeError("ValueId is out of range");
-    }
-    auto& v = values_[id];
-    if (!v.value) {
-      return absl::OutOfRangeError("Value is already deleted");
-    }
-    *value_def = &v;
-    return absl::OkStatus();
-  }
+  absl::Status LookupValue(ValueId id, ValueDef** value_def);

   template <typename Pred>
-  std::vector<Value<TensorT>*> FilterValues(const Pred& predicate) const {
-    std::vector<Value<TensorT>*> values;
+  std::vector<Value*> FilterValues(const Pred& predicate) const {
+    std::vector<Value*> values;
     values.reserve(values_.size());
     for (auto& v : values_) {
       if (v.value != nullptr && predicate(v)) {
@@ -566,8 +214,6 @@ class Model : public Graph<TensorT> {
     return nodes;
   }

-  std::string name_;
-
   // There are two approaches possible: wrap entire NodeDef and ValueDef into
   // unique_ptr and store it in values_ and nodes_ or store it by value.
   // We store it by value here to make introspection calls cheaper.
@@ -581,108 +227,27 @@ class Model : public Graph<TensorT> {
 // Removes to_remove node that precedes to_keep node only if to_remove has
 // outputs that are consumed only by to_keep. In such case to_keep inherits all
 // to_remove inputs.
-template <typename TensorT>
-absl::Status RemovePrecedingNode(Graph<TensorT>* graph, const Node* to_remove,
-                                 const Node* to_keep) {
-  // Make sure all outputs from to_remove are consumed by to_keep.
-  for (auto output : graph->FindOutputs(to_remove->id)) {
-    auto consumers = graph->FindConsumers(output->id);
-    if (consumers.size() > 1 ||
-        (consumers.size() == 1 && consumers[0] != to_keep)) {
-      return absl::InvalidArgumentError(
-          "Output from to_remove node has other consumers");
-    }
-  }
-
-  // Update all references
-  for (auto input : graph->FindInputs(to_remove->id)) {
-    RETURN_IF_ERROR(graph->AddConsumer(to_keep->id, input->id));
-  }
-  for (auto output : graph->FindOutputs(to_remove->id)) {
-    RETURN_IF_ERROR(graph->DeleteValue(output->id));
-  }
-  return graph->DeleteNode(to_remove->id);
-}
+absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove,
+                                 const Node* to_keep);

 // Removes to_remove node that follows to_keep node only if to_remove has inputs
 // that are produced by to_keep. to_keep inherits all to_remove inputs.
-template <typename TensorT>
-absl::Status RemoveFollowingNode(Graph<TensorT>* graph, const Node* to_remove,
-                                 const Node* to_keep) {
-  // Make sure all inputs to to_remove are produced by to_keep.
-  for (auto input : graph->FindInputs(to_remove->id)) {
-    Node* producer = graph->FindProducer(input->id);
-    if (producer->id != to_keep->id) {
-      return absl::InvalidArgumentError("To_remove node has other inputs");
-    }
-  }
-
-  for (auto input : graph->FindInputs(to_remove->id)) {
-    RETURN_IF_ERROR(graph->DeleteValue(input->id));
-  }
-  for (auto output : graph->FindOutputs(to_remove->id)) {
-    RETURN_IF_ERROR(graph->SetProducer(to_keep->id, output->id));
-  }
-  return graph->DeleteNode(to_remove->id);
-}
+absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove,
+                                 const Node* to_keep);

 // Removes to_remove node.
 // Requires that node has one input and one output;
-template <typename TensorT>
-absl::Status RemoveOneInputOneOutputNode(Graph<TensorT>* graph,
-                                         const Node* to_remove) {
-  auto inputs = graph->FindInputs(to_remove->id);
-  auto outputs = graph->FindOutputs(to_remove->id);
-  if (inputs.size() != 1 || outputs.size() != 1) {
-    return absl::InvalidArgumentError(
-        "To_remove node must have 1 input and 1 output");
-  }
-  auto input_id = inputs[0]->id;
-  auto output_id = outputs[0]->id;
-  Node* producer = graph->FindProducer(input_id);
-  auto consumers = graph->FindConsumers(output_id);
-  RETURN_IF_ERROR(graph->DeleteNode(to_remove->id));
-  for (auto& consumer : consumers) {
-    RETURN_IF_ERROR(graph->ReplaceInput(consumer->id, output_id, input_id));
-  }
-  RETURN_IF_ERROR(graph->DeleteValue(output_id));
-  if (!producer && consumers.empty()) {
-    RETURN_IF_ERROR(graph->DeleteValue(input_id));
-  }
-  return absl::OkStatus();
-}
+absl::Status RemoveOneInputOneOutputNode(GraphFloat32* graph,
+                                         const Node* to_remove);

-template <typename TensorT>
-absl::Status AddOutput(Graph<TensorT>* graph, const Node* from_node,
-                       Value<TensorT>** output) {
-  auto link = graph->NewValue();
-  RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id));
-  *output = link;
-  return absl::OkStatus();
-}
+absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
+                       Value** output);

-template <typename TensorT>
-absl::Status ConnectTwoNodes(Graph<TensorT>* graph, const Node* from_node,
-                             const Node* to_node, Value<TensorT>** output) {
-  Value<TensorT>* link;
-  RETURN_IF_ERROR(AddOutput(graph, from_node, &link));
-  RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id));
-  *output = link;
-  return absl::OkStatus();
-}
-
-using GraphFloat32 = Model<TensorRef<BHWC>>;
+absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
+                             const Node* to_node, Value** output);

 // @return true if all tensors have same batch value.
-inline bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
-  const int32_t b = model.values()[0]->tensor.shape.b;
-  for (auto value : model.values()) {
-    if (value->tensor.shape.b != b) {
-      return false;
-    }
-  }
-  return true;
-}
+bool IsBatchMatchesForAllValues(const GraphFloat32& model);

 }  // namespace gpu
 }  // namespace tflite
@@ -66,12 +66,11 @@ namespace {
 // will turn into:
 //   node(copy(output)) <- passthrough_node(output)
 absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node,
-                                const Value<TensorRef<BHWC>>* output,
-                                Node** passthru_node) {
+                                const Value* output, Node** passthru_node) {
   *passthru_node = graph->NewNode();
   // Make copies for every output in the original node.
   RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
-  Value<TensorRef<BHWC>>* copy_output = graph->NewValue();
+  Value* copy_output = graph->NewValue();
   RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
   RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
   copy_output->tensor = output->tensor;
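
The rewiring NewPassthroughNode performs is easiest to see as a topology sketch (informal, derived from the calls above rather than from extra diff context):

    // Before: node ----------------------------> output
    // After:  node -> copy_output -> passthru_node -> output
    //
    // The original node now produces a fresh copy_output value, and the new
    // passthrough node consumes copy_output while taking over production of
    // the original output, so external consumers of `output` are unaffected.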
@@ -323,8 +322,7 @@ absl::Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h,
 }

 // Creates a simple node that holds tensor value.
-absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
-                          Value<TensorRef<BHWC>>** value) {
+absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
   ConstTensorAttributes attr;
   attr.tensor = std::move(t);
   Node* node = graph->NewNode();
@@ -457,16 +455,16 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
     ConcatAttributes attr;
     // Read inputs first to make sure const node is added to a graph before
     // concat node to ensure topological order.
-    std::vector<const Value<TensorRef<BHWC>>*> inputs;
+    std::vector<const Value*> inputs;
     for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
-      Value<TensorRef<BHWC>>* value;
+      Value* value;
       const auto status = reader->ReadValue(idx, &value);
       if (status.ok()) {
         inputs.push_back(value);
       } else {
         TensorFloat32 tensor;
         RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
-        Value<TensorRef<BHWC>>* value;
+        Value* value;
         RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
         inputs.push_back(value);
       }
@@ -475,7 +473,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
     Node* node = graph->NewNode();
     node->operation.type = ToString(OperationType::CONCAT);
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    for (const Value<TensorRef<BHWC>>* input : inputs) {
+    for (const Value* input : inputs) {
       RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
     }
@@ -1011,7 +1009,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
     if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
       auto& reshape = node;
       conv = graph->NewNode();  // reset conv pointer!
-      Value<TensorRef<BHWC>>* reshaped_value = graph->NewValue();
+      Value* reshaped_value = graph->NewValue();
       reshaped_value->tensor.type = DataType::FLOAT32;
       reshaped_value->tensor.shape =
           BHWC(input->tensor.shape.b, 1, 1, weights.shape.w);
@@ -1121,11 +1119,11 @@ class LSTMOperationParser : public TFLiteOperationParser {
     lstm_attr.kernel_type = LstmKernelType::BASIC;
     lstm_node->operation.attributes = lstm_attr;

-    Value<TensorRef<BHWC>>* concat_temp;
+    Value* concat_temp;
     int concat_tensor_idx = tflite_node->outputs->data[2];
     RETURN_IF_ERROR(
         reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
-    Value<TensorRef<BHWC>>* activ_temp;
+    Value* activ_temp;
     int activ_tensor_idx = tflite_node->outputs->data[3];
     RETURN_IF_ERROR(
         reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
@@ -1677,7 +1675,7 @@ class SliceOperationParser : public TFLiteOperationParser {
     Node* node = graph->NewNode();
     node->operation.type = ToString(OperationType::SLICE);
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    Value<TensorRef<BHWC>>* input;
+    Value* input;
     RETURN_IF_ERROR(reader->ReadValue(0, &input));
     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
@@ -1842,7 +1840,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
     Node* node = graph->NewNode();
     node->operation.type = ToString(OperationType::SLICE);
     RETURN_IF_ERROR(reader->AddOutputs(node));
-    Value<TensorRef<BHWC>>* input;
+    Value* input;
     RETURN_IF_ERROR(reader->ReadValue(0, &input));
     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
@@ -2027,7 +2025,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
                          GraphFloat32* graph, ObjectReader* reader) final {
     auto* node = graph->NewNode();
     node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
-    Value<TensorRef<BHWC>>* input;
+    Value* input;
     RETURN_IF_ERROR(reader->ReadValue(2, &input));
     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
     RETURN_IF_ERROR(reader->AddOutputs(node));
@@ -2677,7 +2675,7 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops) {
 absl::Status PrecreateIOTensors(
     TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors,
     std::unordered_map<int, int>* quant_conversion_map,
-    std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value) {
+    std::unordered_map<int, Value*>* tensor_to_value) {
   for (int i = 0; i < io_tensors->size; ++i) {
     const int tensor_index = io_tensors->data[i];
     const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
@@ -2717,7 +2715,7 @@ absl::Status BuildModel(TfLiteContext* context,
     operations.push_back(std::move(op_parser));
     tflite_nodes.push_back(i);
   }
-  std::unordered_map<int, Value<TensorRef<BHWC>>*> tensor_to_value;
+  std::unordered_map<int, Value*> tensor_to_value;
   RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
                                      delegate_params->input_tensors,
                                      quant_conversion_map, &tensor_to_value));
@@ -32,8 +32,8 @@ TEST(Model, SingleNode) {
   // graph_input -> node -> graph_output
   GraphFloat32 graph;
   Node* node = graph.NewNode();
-  Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
-  Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
+  Value* graph_input = graph.NewValue();
+  Value* graph_output = graph.NewValue();
   ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
   ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
@@ -53,9 +53,9 @@ TEST(Model, SingleNodeMultipleOutputs) {
   // graph_input -> node -> (graph_output1, graph_output2)
   GraphFloat32 graph;
   Node* node = graph.NewNode();
-  Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
-  Value<TensorRef<BHWC>>* graph_output1 = graph.NewValue();
-  Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
+  Value* graph_input = graph.NewValue();
+  Value* graph_output1 = graph.NewValue();
+  Value* graph_output2 = graph.NewValue();
   ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
   ASSERT_TRUE(graph.SetProducer(node->id, graph_output1->id).ok());
   ASSERT_TRUE(graph.SetProducer(node->id, graph_output2->id).ok());
@@ -68,7 +68,7 @@ TEST(Model, SingleNodeMultipleOutputs) {
 TEST(Model, SetSameConsumer) {
   GraphFloat32 graph;
   Node* node = graph.NewNode();
-  Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
+  Value* graph_input = graph.NewValue();
   ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
   EXPECT_FALSE(graph.AddConsumer(node->id, graph_input->id).ok());
 }
@@ -77,8 +77,8 @@ TEST(Model, RemoveConsumer) {
   // (graph_input1, graph_input2) -> node
   GraphFloat32 graph;
   Node* node = graph.NewNode();
-  Value<TensorRef<BHWC>>* graph_input1 = graph.NewValue();
-  Value<TensorRef<BHWC>>* graph_input2 = graph.NewValue();
+  Value* graph_input1 = graph.NewValue();
+  Value* graph_input2 = graph.NewValue();
   ASSERT_TRUE(graph.AddConsumer(node->id, graph_input1->id).ok());
   ASSERT_TRUE(graph.AddConsumer(node->id, graph_input2->id).ok());
   EXPECT_THAT(graph.FindConsumers(graph_input1->id),
@@ -102,7 +102,7 @@ TEST(Model, RemoveConsumer) {
 TEST(Model, SetSameProducer) {
   GraphFloat32 graph;
   Node* node = graph.NewNode();
-  Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
+  Value* graph_output = graph.NewValue();
   ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
   EXPECT_FALSE(graph.SetProducer(node->id, graph_output->id).ok());
 }
@@ -110,10 +110,10 @@ TEST(Model, SetSameProducer) {
 TEST(Model, ReplaceInput) {
   GraphFloat32 graph;
   Node* node = graph.NewNode();
-  Value<TensorRef<BHWC>>* v0 = graph.NewValue();
-  Value<TensorRef<BHWC>>* v1 = graph.NewValue();
-  Value<TensorRef<BHWC>>* v2 = graph.NewValue();
-  Value<TensorRef<BHWC>>* v3 = graph.NewValue();
+  Value* v0 = graph.NewValue();
+  Value* v1 = graph.NewValue();
+  Value* v2 = graph.NewValue();
+  Value* v3 = graph.NewValue();
   ASSERT_TRUE(graph.AddConsumer(node->id, v0->id).ok());
   ASSERT_TRUE(graph.AddConsumer(node->id, v1->id).ok());
   ASSERT_TRUE(graph.AddConsumer(node->id, v2->id).ok());
@ -125,7 +125,7 @@ TEST(Model, ReplaceInput) {
|
||||
TEST(Model, RemoveProducer) {
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||
EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
|
||||
@ -142,8 +142,8 @@ TEST(Model, RemoveProducer) {
|
||||
TEST(Model, RemoveSimpleNodeDegenerateCase) {
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value* graph_input = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||
@ -161,9 +161,9 @@ TEST(Model, RemoveSimpleNodeNoPreviousNode) {
|
||||
GraphFloat32 graph;
|
||||
Node* simple_node = graph.NewNode();
|
||||
Node* consumer_node = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
Value* graph_input = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
Value* value = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(simple_node->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(simple_node->id, value->id).ok());
|
||||
@ -183,9 +183,9 @@ TEST(Model, RemoveSimpleNodeNoAfterNodes) {
|
||||
GraphFloat32 graph;
|
||||
Node* simple_node = graph.NewNode();
|
||||
Node* producer_node = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
Value* graph_input = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
Value* value = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(simple_node->id, value->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(simple_node->id, graph_output->id).ok());
|
||||
@ -206,10 +206,10 @@ TEST(Model, RemoveSimpleNodeGeneralCase) {
|
||||
Node* simple_node = graph.NewNode();
|
||||
Node* producer_node = graph.NewNode();
|
||||
Node* consumer_node = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value0 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value1 = graph.NewValue();
|
||||
Value* graph_input = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
Value* value0 = graph.NewValue();
|
||||
Value* value1 = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(producer_node->id, value0->id).ok());
|
||||
@ -257,12 +257,12 @@ TEST(Model, RemoveSimpleNodeComplexCase) {
|
||||
Node* n0 = graph.NewNode();
|
||||
Node* n1 = graph.NewNode(); // node to remove
|
||||
Node* n2 = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* v0 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* v1 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* v2 = graph.NewValue(); // value to be removed
|
||||
Value<TensorRef<BHWC>>* v3 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* o1 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* o2 = graph.NewValue();
|
||||
Value* v0 = graph.NewValue();
|
||||
Value* v1 = graph.NewValue();
|
||||
Value* v2 = graph.NewValue(); // value to be removed
|
||||
Value* v3 = graph.NewValue();
|
||||
Value* o1 = graph.NewValue();
|
||||
Value* o2 = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(n0->id, v0->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(n0->id, v1->id).ok());
|
||||
@ -289,14 +289,14 @@ TEST(Model, CircularDependency) {
|
||||
{
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
Value* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, value->id).ok());
|
||||
EXPECT_FALSE(graph.SetProducer(node->id, value->id).ok());
|
||||
}
|
||||
{
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
Value* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, value->id).ok());
|
||||
EXPECT_FALSE(graph.AddConsumer(node->id, value->id).ok());
|
||||
}
|
||||
@ -309,8 +309,8 @@ TEST(Model, ReassignValue) {
|
||||
GraphFloat32 graph;
|
||||
Node* node1 = graph.NewNode();
|
||||
Node* node2 = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value* graph_input = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node1->id, graph_output->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(node2->id, graph_input->id).ok());
|
||||
@ -336,9 +336,9 @@ TEST(Model, DeleteValue) {
|
||||
GraphFloat32 graph;
|
||||
Node* node1 = graph.NewNode();
|
||||
Node* node2 = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
Value* graph_input = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
Value* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
|
||||
@ -377,10 +377,10 @@ TEST(Model, DeleteNode) {
|
||||
Node* node1 = graph.NewNode();
|
||||
Node* node2 = graph.NewNode();
|
||||
Node* node3 = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
Value* graph_input = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
Value* graph_output2 = graph.NewValue();
|
||||
Value* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
|
||||
@ -437,9 +437,9 @@ TEST(Model, InsertNodeAfter) {
|
||||
GraphFloat32 graph;
|
||||
Node* node1 = graph.NewNode();
|
||||
Node* node2 = graph.NewNode();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
Value* graph_input = graph.NewValue();
|
||||
Value* graph_output = graph.NewValue();
|
||||
Value* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
|
||||
|
@ -28,10 +28,9 @@ namespace tflite {
namespace gpu {

absl::Status ObjectReader::ReadNonConstantTensor(
TfLiteContext* context,
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
TfLiteContext* context, std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
uint32_t tensor_idx, Value<TensorRef<BHWC>>** value) {
uint32_t tensor_idx, Value** value) {
if (tensor_idx >= context->tensors_size) {
return absl::OutOfRangeError(
absl::StrCat("ReadNonConstTensor: input tensor index: ", tensor_idx));
@ -63,7 +62,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
(*quant_conversion_map)[fp_tensor_index] = tensor_idx;
(*quant_conversion_map)[tensor_idx] = fp_tensor_index;
// Add a new GPU Value for the new dequantized floating-point tensor.
Value<TensorRef<BHWC>>* value = graph->NewValue();
Value* value = graph->NewValue();
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
value->tensor.ref = fp_tensor_index;
@ -77,7 +76,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
tensor_idx = quant_conversion_map->at(tensor_idx);
} else {
// Floating-point case.
Value<TensorRef<BHWC>>* value = graph->NewValue();
Value* value = graph->NewValue();
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
value->tensor.ref = tensor_idx;
@ -91,8 +90,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
return absl::OkStatus();
}

absl::Status ObjectReader::ReadValue(uint32_t idx,
Value<TensorRef<BHWC>>** value) {
absl::Status ObjectReader::ReadValue(uint32_t idx, Value** value) {
if (idx >= node_->inputs->size) {
return absl::OutOfRangeError(
absl::StrCat("ReadValue: input tensor index: ", idx));
@ -100,8 +98,8 @@ absl::Status ObjectReader::ReadValue(uint32_t idx,
return ReadValueByTensorIdx(node_->inputs->data[idx], value);
}

absl::Status ObjectReader::ReadValueByTensorIdx(
uint32_t tensor_idx, Value<TensorRef<BHWC>>** value) {
absl::Status ObjectReader::ReadValueByTensorIdx(uint32_t tensor_idx,
Value** value) {
// Constant tensors should be handled by ReadTensor.
return ReadNonConstantTensor(context_, tensor_to_value_,
quant_conversion_map_, graph_, tensor_idx,
@ -133,7 +131,7 @@ absl::Status ObjectReader::AddOutput(const Node* node, int id) {
node_->outputs->size));
}
int output_tensor_idx = node_->outputs->data[id];
Value<TensorRef<BHWC>>* value;
Value* value;
RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
return absl::OkStatus();
@ -147,7 +145,7 @@ absl::Status ObjectReader::AddOutputs(const Node* node) {
}

absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
Value<TensorRef<BHWC>>* input;
Value* input;
RETURN_IF_ERROR(ReadValue(idx, &input));
return graph_->AddConsumer(node->id, input->id);
}
|
@ -34,25 +34,23 @@ namespace gpu {
class ObjectReader {
public:
static absl::Status ReadNonConstantTensor(
TfLiteContext* context,
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
TfLiteContext* context, std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
uint32_t tensor_idx, Value<TensorRef<BHWC>>** value = nullptr);
uint32_t tensor_idx, Value** value = nullptr);

ObjectReader(
GraphFloat32* graph, TfLiteContext* context, const TfLiteNode* node,
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map = nullptr)
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
const TfLiteNode* node,
std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map = nullptr)
: graph_(graph),
context_(context),
node_(node),
tensor_to_value_(tensor_to_value),
quant_conversion_map_(quant_conversion_map) {}

absl::Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value);
absl::Status ReadValue(uint32_t idx, Value** value);

absl::Status ReadValueByTensorIdx(uint32_t tensor_idx,
Value<TensorRef<BHWC>>** value);
absl::Status ReadValueByTensorIdx(uint32_t tensor_idx, Value** value);

int GetNumberOfRuntimeInputs() const;

@ -89,7 +87,7 @@ class ObjectReader {
GraphFloat32* graph_;
TfLiteContext* context_;
const TfLiteNode* node_;
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value_;
std::unordered_map<int, Value*>* tensor_to_value_;
std::unordered_map<int, int>* quant_conversion_map_;
};

|
@ -64,7 +64,7 @@ class AddQuantAdjustments : public NodeTransformation {

// Add one output Value for the new node.
// The tensor information should remain the same.
Value<TensorRef<BHWC>>* adjusted_value = graph->NewValue();
Value* adjusted_value = graph->NewValue();
adjusted_value->tensor = output_value->tensor;
status =
graph->SetProducer(quant_and_dequant_node->id, adjusted_value->id);
|
@ -59,7 +59,7 @@ TEST(AddQuantAdjustments, OneNode) {

ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());

Value<TensorRef<BHWC>>* output;
Value* output;
AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/2.0,
/*scale=*/0.008);
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
@ -114,18 +114,18 @@ TEST(AddQuantAdjustments, GeneralCase) {

// Connections.
ASSERT_TRUE(graph.AddConsumer(add1_node->id, input->id).ok());
Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, add1_node, quant_node, &link1).ok());
AddQuantParams(&link1->quant_params, /*min=*/0.0, /*max=*/2.0,
/*scale=*/0.008);
link1->tensor.shape = BHWC(1, 4, 4, 8);
ASSERT_TRUE(graph.AddConsumer(add2_node->id, link1->id).ok());
Value<TensorRef<BHWC>>* link2;
Value* link2;
ASSERT_TRUE(ConnectTwoNodes(&graph, quant_node, add2_node, &link2).ok());
AddQuantParams(&link2->quant_params, /*min=*/-1.0, /*max=*/1.0,
/*scale=*/0.008);
link2->tensor.shape = BHWC(1, 4, 4, 8);
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add2_node, &output).ok());
AddQuantParams(&output->quant_params, /*min=*/-1.0, /*max=*/1.0,
/*scale=*/0.008);
|
@ -57,11 +57,11 @@ TEST(MergeConvolutionWithAddTest, Smoke) {

ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());

Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);

Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);

@ -108,11 +108,11 @@ TEST(MergeAddWithConvolutionTest, Smoke) {

ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());

Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);

Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);

|
@ -58,11 +58,11 @@ TEST(MergeConvolutionWithMulTest, Smoke) {

ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());

Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);

Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);

@ -109,11 +109,11 @@ TEST(MergeMulWithConvolutionTest, Smoke) {

ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok());

Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);

Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);

|
@ -68,16 +68,16 @@ TEST(MakeFullyConnected, Smoke) {

ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok());

Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok());
output->tensor.shape = BHWC(1, 1, 1, 32);

Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(
ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);

Value<TensorRef<BHWC>>* link2;
Value* link2;
ASSERT_TRUE(
ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok());
link2->tensor.shape = BHWC(1, 1, 1, 16);
|
@ -38,7 +38,7 @@ TEST(MakePadding, Smoke) {
attr.axis = Axis::HEIGHT;
concat_node->operation.attributes = attr;

Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok());
output->tensor.shape = BHWC(1, 7, 3, 5);

@ -50,7 +50,7 @@ TEST(MakePadding, Smoke) {
std::vector<float>(const_attr.tensor.shape.DimensionsProduct(), 0);
const_node->operation.attributes = const_attr;

Value<TensorRef<BHWC>>* const_link;
Value* const_link;
ASSERT_TRUE(
ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok());
const_link->tensor.shape = const_attr.tensor.shape;
|
@ -40,7 +40,7 @@ TEST(MergePaddingWith, Smoke) {
pad_node->operation.attributes = attr;

auto conv_node = graph.NewNode();
Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok());
ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok());
conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
@ -77,7 +77,7 @@ TEST(MergePaddingWith, MergeTwo) {
pad_node1->operation.attributes = attr;

auto pad_node2 = graph.NewNode();
Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok());
pad_node2->operation.type = ToString(OperationType::PAD);
attr.prepended = BHWC(0, 0, 0, 0);
|
@ -35,12 +35,12 @@ TEST(RemoveSingleInputAdd, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());

auto add_node = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
add_node->operation.attributes = AddAttributes();

Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size());
@ -63,14 +63,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Tensor) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());

auto add_node = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
AddAttributes attr;
attr.param = Tensor<Linear, DataType::FLOAT32>();
add_node->operation.attributes = attr;

Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size());
@ -90,14 +90,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Scalar) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());

auto add_node = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
AddAttributes attr;
attr.param = 0.5f;
add_node->operation.attributes = attr;

Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size());
@ -119,11 +119,11 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Multiple) {
ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok());

auto add_node = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);

Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok());
ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok());
ASSERT_EQ(3, graph.nodes().size());
@ -144,7 +144,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());

auto node_to_remove = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
output->tensor.shape = BHWC(1, 5, 5, 1);
node_to_remove->operation.type = ToString(OperationType::RESIZE);
@ -153,7 +153,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
attr.type = SamplingType::BILINEAR;
node_to_remove->operation.attributes = attr;

Value<TensorRef<BHWC>>* link;
Value* link;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
link->tensor.shape = output->tensor.shape;
ASSERT_EQ(2, graph.nodes().size());
@ -175,10 +175,10 @@ TEST(RemoveIdentityReshape, Smoke) {
Node* simple_node = graph.NewNode();
Node* producer_node = graph.NewNode();
Node* consumer_node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value0 = graph.NewValue();
Value<TensorRef<BHWC>>* value1 = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
Value* value0 = graph.NewValue();
Value* value1 = graph.NewValue();

value0->tensor.shape = BHWC(1, 1, 1, 11);
simple_node->operation.type = ToString(OperationType::RESHAPE);
|
@ -580,8 +580,7 @@ class InferenceBuilderImpl : public InferenceBuilder {

private:
// Links internal tensors with external user-facing objects.
std::vector<TensorTieDef> LinkTensors(
const std::vector<Value<TensorRef<BHWC>>*>& values) {
std::vector<TensorTieDef> LinkTensors(const std::vector<Value*>& values) {
std::vector<TensorTieDef> links;
links.reserve(values.size());
for (const auto& value : values) {
|
@ -145,7 +145,7 @@ class Delegate {

// TODO(impjdi): Remove code duplication.
auto values = graph.values();
auto find_value = [&](int tensor_index) -> Value<TensorRef<BHWC>>* {
auto find_value = [&](int tensor_index) -> Value* {
for (auto value : values) {
if (value->tensor.ref == tensor_index) return value;
}
|
@ -54,20 +54,17 @@ namespace gpu {
namespace metal {
namespace {

bool IsWidthBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsWidthBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
inputs[1]->tensor.shape.w == 1;
}
bool IsHeightBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsHeightBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
inputs[1]->tensor.shape.h == 1;
}
bool IsChannelsBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsChannelsBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
inputs[1]->tensor.shape.c == 1;
|
@ -239,7 +239,7 @@ class Delegate {

// TODO(impjdi): Remove code duplication.
auto values = graph.values();
auto find_value = [&](int tensor_index) -> Value<TensorRef<BHWC>>* {
auto find_value = [&](int tensor_index) -> Value* {
for (auto value : values) {
if (value->tensor.ref == tensor_index) return value;
}
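
For reference, a minimal sketch of what graph construction reads like after this change, using only calls that appear in the hunks above (GraphFloat32, NewNode, NewValue, AddConsumer, SetProducer, BHWC, RETURN_IF_ERROR). The include paths and the BuildTinyGraph wrapper are illustrative assumptions, not part of this commit:

#include "tensorflow/lite/delegates/gpu/common/model.h"   // assumed location of GraphFloat32/Node/Value
#include "tensorflow/lite/delegates/gpu/common/shape.h"   // assumed location of BHWC
#include "tensorflow/lite/delegates/gpu/common/status.h"  // assumed location of RETURN_IF_ERROR

namespace tflite {
namespace gpu {

// Hypothetical helper: builds input -> node -> output with the new
// non-template API. Before this commit, each declaration below would
// have been spelled Value<TensorRef<BHWC>>* instead of Value*.
absl::Status BuildTinyGraph(GraphFloat32* graph) {
  Node* node = graph->NewNode();
  Value* input = graph->NewValue();
  Value* output = graph->NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);   // tensor metadata is unchanged
  output->tensor.shape = BHWC(1, 4, 4, 8);
  RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
  RETURN_IF_ERROR(graph->SetProducer(node->id, output->id));
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite

The change is mechanical: since GraphFloat32 only ever instantiated the templates with TensorRef<BHWC>, dropping the template parameter loses no generality, which is why every hunk above is a one-for-one textual substitution.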