Convert template classes to simple classes:

* Value<TensorRef<BHWC>> is now Value.
* The Graph<TensorRef<BHWC>> interface, its Model<TensorRef<BHWC>> implementation, and the GraphFloat32 typedef are now a single class, GraphFloat32.

PiperOrigin-RevId: 306926255
Change-Id: I6730c1a999504c7484aec16ecf4b6aa213607c70
Author: Juhyun Lee, 2020-04-16 14:51:37 -07:00 (committed by TensorFlower Gardener)
Parent: 339eab07e8
Commit: 22350e7ca8
27 changed files with 648 additions and 648 deletions
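At its core, the change drops a template parameter that was only ever instantiated as TensorRef<BHWC>. A before/after sketch of the Value type, distilled from the model.h diff below:

// Before: Value was a class template, and GraphFloat32 a typedef.
template <typename TensorT>
struct Value {
  using TensorType = TensorT;
  const ValueId id;
  TensorType tensor;
  absl::optional<QuantizationParams> quant_params;
};
using GraphFloat32 = Model<TensorRef<BHWC>>;

// After: Value is a plain struct fixed to TensorRef<BHWC>, and
// GraphFloat32 is a concrete class instead of a typedef.
struct Value {
  const ValueId id;
  TensorRef<BHWC> tensor;
  absl::optional<QuantizationParams> quant_params;
};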


@@ -661,9 +661,8 @@ class InferenceBuilderImpl : public InferenceBuilder {
}
// Links internal tensors with external user-facing objects.
std::vector<TensorTieDef> LinkTensors(
const GraphFloat32& graph,
const std::vector<Value<TensorRef<BHWC>>*>& values) {
std::vector<TensorTieDef> LinkTensors(const GraphFloat32& graph,
const std::vector<Value*>& values) {
std::vector<TensorTieDef> links;
links.reserve(values.size());
for (const auto& value : values) {


@@ -119,9 +119,8 @@ bool IsBufferBased(const TensorStorageType& type) {
// Generic add is an add that has several runtime inputs which are not
// broadcast, i.e. a pointwise add for N tensors where N > 1.
bool IsGenericAdd(const Node& node,
const std::vector<Value<TensorRef<BHWC>>*>& inputs,
const std::vector<Value<TensorRef<BHWC>>*>& outputs) {
bool IsGenericAdd(const Node& node, const std::vector<Value*>& inputs,
const std::vector<Value*>& outputs) {
if (inputs.size() == 1) {
return false;
}


@@ -30,9 +30,8 @@ namespace cl {
absl::Status SelectDefault(const CreationContext& creation_context,
const OperationDef& op_def, ModelHints hints,
const std::vector<Value<TensorRef<BHWC>>*>& inputs,
const std::vector<Value<TensorRef<BHWC>>*>& outputs,
const Node& node,
const std::vector<Value*>& inputs,
const std::vector<Value*>& outputs, const Node& node,
GPUOperationsSubgraph* gpu_subgraph) {
return absl::UnimplementedError(
absl::StrCat("No selector for ", node.operation.type));


@@ -31,9 +31,8 @@ namespace cl {
absl::Status SelectDefault(const CreationContext& creation_context,
const OperationDef& op_def, ModelHints hints,
const std::vector<Value<TensorRef<BHWC>>*>& inputs,
const std::vector<Value<TensorRef<BHWC>>*>& outputs,
const Node& node,
const std::vector<Value*>& inputs,
const std::vector<Value*>& outputs, const Node& node,
GPUOperationsSubgraph* gpu_subgraph);
} // namespace cl


@@ -37,20 +37,17 @@ namespace gpu {
namespace cl {
namespace {
bool IsWidthBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsWidthBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
inputs[1]->tensor.shape.w == 1;
}
bool IsHeightBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsHeightBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
inputs[1]->tensor.shape.h == 1;
}
bool IsChannelsBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsChannelsBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
inputs[1]->tensor.shape.c == 1;
@@ -146,11 +143,12 @@ absl::Status WinogradFromNode(const CreationContext& creation_context,
} // namespace
absl::Status GPUOperationFromNode(
const CreationContext& creation_context, const OperationDef& op_def,
ModelHints hints, const std::vector<Value<TensorRef<BHWC>>*>& inputs,
const std::vector<Value<TensorRef<BHWC>>*>& outputs, const Node& node,
GPUOperationsSubgraph* gpu_subgraph) {
absl::Status GPUOperationFromNode(const CreationContext& creation_context,
const OperationDef& op_def, ModelHints hints,
const std::vector<Value*>& inputs,
const std::vector<Value*>& outputs,
const Node& node,
GPUOperationsSubgraph* gpu_subgraph) {
std::unique_ptr<GPUOperation>* gpu_op =
InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
auto op_type = OperationTypeFromString(node.operation.type);
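The broadcast predicates above look only at shapes, so their behavior is easy to illustrate. A minimal sketch, assuming a GraphFloat32 named graph is in scope (the predicates sit in an anonymous namespace, so this is conceptual rather than externally callable; the names a and b are hypothetical):

// Two runtime inputs where the second has width 1 while the first does not.
Value* a = graph.NewValue();         // `a` is an illustrative name
a->tensor.shape = BHWC(1, 4, 4, 8);
Value* b = graph.NewValue();         // `b` is an illustrative name
b->tensor.shape = BHWC(1, 4, 1, 8);  // w == 1
// IsWidthBroadcastedForSecondInput({a, b}) would return true: there are
// exactly two inputs, their widths differ (4 vs. 1), and the second width
// is 1. The height and channel variants apply the same test to shape.h and
// shape.c.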


@@ -29,11 +29,12 @@ namespace tflite {
namespace gpu {
namespace cl {
absl::Status GPUOperationFromNode(
const CreationContext& creation_context, const OperationDef& op_def,
ModelHints hints, const std::vector<Value<TensorRef<BHWC>>*>& inputs,
const std::vector<Value<TensorRef<BHWC>>*>& outputs, const Node& node,
GPUOperationsSubgraph* gpu_subgraph);
absl::Status GPUOperationFromNode(const CreationContext& creation_context,
const OperationDef& op_def, ModelHints hints,
const std::vector<Value*>& inputs,
const std::vector<Value*>& outputs,
const Node& node,
GPUOperationsSubgraph* gpu_subgraph);
} // namespace cl
} // namespace gpu


@@ -26,8 +26,7 @@ namespace gpu {
namespace cl {
std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
const std::vector<Value<TensorRef<BHWC>>*>& inputs,
const std::vector<Value<TensorRef<BHWC>>*>& outputs,
const std::vector<Value*>& inputs, const std::vector<Value*>& outputs,
GPUOperationsSubgraph* gpu_subgraph) {
gpu_subgraph->operations.clear();
gpu_subgraph->new_tensors.clear();


@@ -43,8 +43,7 @@ struct GPUOperationsSubgraph {
};
std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
const std::vector<Value<TensorRef<BHWC>>*>& inputs,
const std::vector<Value<TensorRef<BHWC>>*>& outputs,
const std::vector<Value*>& inputs, const std::vector<Value*>& outputs,
GPUOperationsSubgraph* gpu_subgraph);
} // namespace cl


@@ -78,6 +78,7 @@ cc_library(
cc_library(
name = "model",
srcs = ["model.cc"],
hdrs = ["model.h"],
deps = [
":data_type",


@@ -0,0 +1,451 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
std::vector<Node*> GraphFloat32::nodes() const {
return FilterNodes([](const NodeDef&) { return true; });
}
std::vector<Value*> GraphFloat32::values() const {
return FilterValues([](const ValueDef&) { return true; });
}
std::vector<Value*> GraphFloat32::inputs() const {
return FilterValues([](const ValueDef& v) { return v.producer == nullptr; });
}
std::vector<Value*> GraphFloat32::outputs() const {
return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
}
std::vector<Value*> GraphFloat32::FindInputs(NodeId id) const {
if (id >= nodes_.size()) {
return {};
}
return nodes_.at(id).inputs;
}
std::vector<Value*> GraphFloat32::FindOutputs(NodeId id) const {
if (id >= nodes_.size()) {
return {};
}
return nodes_.at(id).outputs;
}
bool GraphFloat32::IsGraphInput(ValueId id) const {
if (id >= values_.size()) {
return false;
}
return values_[id].producer == nullptr;
}
bool GraphFloat32::IsGraphOutput(ValueId id) const {
if (id >= values_.size()) {
return false;
}
return values_[id].consumers.empty();
}
Node* GraphFloat32::FindProducer(ValueId id) const {
if (id >= values_.size()) {
return nullptr;
}
return values_[id].producer;
}
std::vector<Node*> GraphFloat32::FindConsumers(ValueId id) const {
if (id >= values_.size()) {
return {};
}
return values_[id].consumers;
}
Node* GraphFloat32::GetNode(NodeId id) const {
if (id >= nodes_.size()) {
return {};
}
return nodes_.at(id).node.get();
}
Value* GraphFloat32::GetValue(ValueId id) const {
if (id >= values_.size()) {
return nullptr;
}
return values_[id].value.get();
}
Node* GraphFloat32::NewNode() {
const NodeId new_id = nodes_.size();
NodeDef def;
def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
Node* node = def.node.get();
nodes_[new_id] = std::move(def);
execution_plan_.push_back(new_id);
return node;
}
absl::Status GraphFloat32::InsertNodeAfter(NodeId id, Node** new_node) {
if (id >= nodes_.size()) {
return absl::OutOfRangeError("NodeId is out of range");
}
int idx = 0;
while (idx < execution_plan_.size()) {
if (execution_plan_[idx] == id) break;
++idx;
}
if (idx == execution_plan_.size()) {
return absl::OutOfRangeError("NodeId not in execution plan");
}
const NodeId new_id = nodes_.size();
NodeDef def;
def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
*new_node = def.node.get();
nodes_[new_id] = std::move(def);
execution_plan_.insert(execution_plan_.begin() + idx + 1, new_id);
return absl::OkStatus();
}
Value* GraphFloat32::NewValue() {
ValueDef def;
def.value =
absl::make_unique<Value>(Value{static_cast<ValueId>(values_.size()), {}});
Value* value = def.value.get();
values_.push_back(std::move(def));
return value;
}
absl::Status GraphFloat32::SetProducer(NodeId producer, ValueId value) {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(value, &v));
Value* value_ptr = v->value.get();
NodeDef* n;
RETURN_IF_ERROR(LookupNode(producer, &n));
Node* node_ptr = n->node.get();
// check if this value has the same producer already
if (node_ptr == v->producer) {
return absl::AlreadyExistsError(absl::StrCat(
"Node ", producer, " is already a producer of the value ", value));
}
// Check if the node is a consumer of this value.
if (IsInput(producer, value)) {
return absl::InvalidArgumentError("Node is a consumer of the value");
}
if (v->producer != nullptr) {
// value is no longer produced by its previous producer.
Erase(&nodes_[v->producer->id].outputs, value_ptr);
}
v->producer = node_ptr;
n->outputs.push_back(value_ptr);
return absl::OkStatus();
}
absl::Status GraphFloat32::RemoveProducer(ValueId value) {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(value, &v));
Value* value_ptr = v->value.get();
if (v->producer == nullptr) {
return absl::InvalidArgumentError("Value does not have a producer");
}
Erase(&nodes_[v->producer->id].outputs, value_ptr);
v->producer = nullptr;
return absl::OkStatus();
}
absl::Status GraphFloat32::AddConsumer(NodeId consumer, ValueId value) {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(value, &v));
Value* value_ptr = v->value.get();
NodeDef* n;
RETURN_IF_ERROR(LookupNode(consumer, &n));
Node* node_ptr = n->node.get();
// Check if the node is already a producer of the value.
if (node_ptr == v->producer) {
return absl::InvalidArgumentError("Node is a producer of the value");
}
// check if this value has the same consumer already
if (IsInput(consumer, value)) {
return absl::AlreadyExistsError(absl::StrCat(
"Node ", consumer, " is already a consumer of the value ", value));
}
n->inputs.push_back(value_ptr);
v->consumers.push_back(node_ptr);
return absl::OkStatus();
}
// Replace input value for given node.
absl::Status GraphFloat32::ReplaceInput(NodeId node, ValueId old_value,
ValueId new_value) {
ValueDef* v_old;
RETURN_IF_ERROR(LookupValue(old_value, &v_old));
Value* value_old_ptr = v_old->value.get();
ValueDef* v_new;
RETURN_IF_ERROR(LookupValue(new_value, &v_new));
Value* value_new_ptr = v_new->value.get();
NodeDef* n;
RETURN_IF_ERROR(LookupNode(node, &n));
Node* node_ptr = n->node.get();
// Check if the node is a consumer of old_value.
if (!IsInput(node, old_value)) {
return absl::InvalidArgumentError("old_value must be input of node.");
}
// Check if the node is not a consumer of new_value.
if (IsInput(node, new_value)) {
return absl::InvalidArgumentError("new_value can not be input of node.");
}
// Check if this value has the same producer already
if (node_ptr == v_new->producer) {
return absl::InvalidArgumentError("new_value can not be output of node.");
}
for (int i = 0; i < n->inputs.size(); ++i) {
if (n->inputs[i] == value_old_ptr) {
n->inputs[i] = value_new_ptr;
break;
}
}
v_new->consumers.push_back(node_ptr);
Erase(&v_old->consumers, node_ptr);
return absl::OkStatus();
}
absl::Status GraphFloat32::RemoveConsumer(NodeId consumer, ValueId value) {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(value, &v));
Value* value_ptr = v->value.get();
NodeDef* n;
RETURN_IF_ERROR(LookupNode(consumer, &n));
Node* node_ptr = n->node.get();
if (!IsInput(consumer, value)) {
return absl::InvalidArgumentError("Node is not a consumer of the value");
}
Erase(&n->inputs, value_ptr);
Erase(&v->consumers, node_ptr);
return absl::OkStatus();
}
absl::Status GraphFloat32::DeleteNode(NodeId id) {
NodeDef* n;
RETURN_IF_ERROR(LookupNode(id, &n));
Node* node_ptr = n->node.get();
for (auto value : n->inputs) {
Erase(&values_[value->id].consumers, node_ptr);
}
for (auto value : n->outputs) {
values_[value->id].producer = nullptr;
}
n->inputs.clear();
n->outputs.clear();
n->node.reset();
return absl::OkStatus();
}
absl::Status GraphFloat32::DeleteValue(ValueId id) {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(id, &v));
Value* value_ptr = v->value.get();
if (v->producer != nullptr) {
Erase(&nodes_[v->producer->id].outputs, value_ptr);
}
if (!v->consumers.empty()) {
for (auto node : v->consumers) {
Erase(&nodes_[node->id].inputs, value_ptr);
}
}
v->producer = nullptr;
v->consumers.clear();
v->value.reset();
return absl::OkStatus();
}
absl::Status GraphFloat32::MakeExactCopy(GraphFloat32* model) const {
model->nodes_.clear();
model->execution_plan_.clear();
model->values_.clear();
for (auto& value_def : values_) {
model->values_.push_back({});
if (value_def.value) {
model->values_.back().value = absl::make_unique<Value>(*value_def.value);
}
}
// Add all nodes first.
for (auto node_id : execution_plan_) {
model->execution_plan_.push_back(node_id);
model->nodes_[node_id] = {};
auto& node_def = nodes_.at(node_id);
if (node_def.node) {
model->nodes_[node_id].node = absl::make_unique<Node>(*node_def.node);
}
}
// Wire up dependencies between nodes.
for (auto node_id : execution_plan_) {
auto& node_def = nodes_.at(node_id);
if (node_def.node) {
for (auto output : node_def.outputs) {
RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id));
}
for (auto input : node_def.inputs) {
RETURN_IF_ERROR(model->AddConsumer(node_def.node->id, input->id));
}
}
}
return absl::OkStatus();
}
bool GraphFloat32::IsInput(NodeId node, ValueId value) {
if (node >= nodes_.size() || value >= values_.size()) {
return false;
}
const NodeDef& n = nodes_[node];
const ValueDef& v = values_[value];
if (!n.node || !v.value) {
return false;
}
return std::find(n.inputs.begin(), n.inputs.end(), v.value.get()) !=
n.inputs.end();
}
absl::Status GraphFloat32::LookupNode(NodeId id, NodeDef** node_def) {
if (id >= nodes_.size()) {
return absl::OutOfRangeError("NodeId is out of range");
}
auto& n = nodes_[id];
if (!n.node) {
return absl::OutOfRangeError("Node is already deleted");
}
*node_def = &n;
return absl::OkStatus();
}
absl::Status GraphFloat32::LookupValue(ValueId id, ValueDef** value_def) {
if (id >= values_.size()) {
return absl::OutOfRangeError("ValueId is out of range");
}
auto& v = values_[id];
if (!v.value) {
return absl::OutOfRangeError("Value is already deleted");
}
*value_def = &v;
return absl::OkStatus();
}
absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove,
const Node* to_keep) {
// Make sure all outputs from to_remove are consumed by to_keep.
for (auto output : graph->FindOutputs(to_remove->id)) {
auto consumers = graph->FindConsumers(output->id);
if (consumers.size() > 1 ||
(consumers.size() == 1 && consumers[0] != to_keep)) {
return absl::InvalidArgumentError(
"Output from to_remove node has other consumers");
}
}
// Update all references
for (auto input : graph->FindInputs(to_remove->id)) {
RETURN_IF_ERROR(graph->AddConsumer(to_keep->id, input->id));
}
for (auto output : graph->FindOutputs(to_remove->id)) {
RETURN_IF_ERROR(graph->DeleteValue(output->id));
}
return graph->DeleteNode(to_remove->id);
}
absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove,
const Node* to_keep) {
// Make sure all inputs to to_remove are produced by to_keep.
for (auto input : graph->FindInputs(to_remove->id)) {
Node* producer = graph->FindProducer(input->id);
if (producer->id != to_keep->id) {
return absl::InvalidArgumentError("To_remove node has other inputs");
}
}
for (auto input : graph->FindInputs(to_remove->id)) {
RETURN_IF_ERROR(graph->DeleteValue(input->id));
}
for (auto output : graph->FindOutputs(to_remove->id)) {
RETURN_IF_ERROR(graph->SetProducer(to_keep->id, output->id));
}
return graph->DeleteNode(to_remove->id);
}
absl::Status RemoveOneInputOneOutputNode(GraphFloat32* graph,
const Node* to_remove) {
auto inputs = graph->FindInputs(to_remove->id);
auto outputs = graph->FindOutputs(to_remove->id);
if (inputs.size() != 1 || outputs.size() != 1) {
return absl::InvalidArgumentError(
"To_remove node must have 1 input and 1 output");
}
auto input_id = inputs[0]->id;
auto output_id = outputs[0]->id;
Node* producer = graph->FindProducer(input_id);
auto consumers = graph->FindConsumers(output_id);
RETURN_IF_ERROR(graph->DeleteNode(to_remove->id));
for (auto& consumer : consumers) {
RETURN_IF_ERROR(graph->ReplaceInput(consumer->id, output_id, input_id));
}
RETURN_IF_ERROR(graph->DeleteValue(output_id));
if (!producer && consumers.empty()) {
RETURN_IF_ERROR(graph->DeleteValue(input_id));
}
return absl::OkStatus();
}
absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
Value** output) {
auto link = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id));
*output = link;
return absl::OkStatus();
}
absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
const Node* to_node, Value** output) {
Value* link;
RETURN_IF_ERROR(AddOutput(graph, from_node, &link));
RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id));
*output = link;
return absl::OkStatus();
}
bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
const int32_t b = model.values()[0]->tensor.shape.b;
for (auto value : model.values()) {
if (value->tensor.shape.b != b) {
return false;
}
}
return true;
}
} // namespace gpu
} // namespace tflite
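The model_test.cc diff further down exercises this API end to end. A minimal usage sketch of the new non-template class, mirroring the SingleNode test (RETURN_IF_ERROR comes from the delegate's status.h; the function name is hypothetical):

absl::Status BuildSingleNodeGraph(GraphFloat32* graph) {
  // graph_input -> node -> graph_output
  Node* node = graph->NewNode();
  Value* graph_input = graph->NewValue();
  Value* graph_output = graph->NewValue();
  RETURN_IF_ERROR(graph->AddConsumer(node->id, graph_input->id));
  RETURN_IF_ERROR(graph->SetProducer(node->id, graph_output->id));
  // graph_input has no producer, so it appears in graph->inputs();
  // graph_output has no consumers, so it appears in graph->outputs().
  return absl::OkStatus();
}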


@@ -50,77 +50,70 @@ struct QuantizationParams {
};
// Connects a tensor's producer and the operations that depend on this tensor.
template <typename TensorT>
struct Value {
using TensorType = TensorT;
const ValueId id;
TensorType tensor;
TensorRef<BHWC> tensor;
absl::optional<QuantizationParams> quant_params;
};
struct Operation {
std::string type;
absl::any attributes;
};
struct Node {
const NodeId id;
Operation operation;
};
// Graph is a DAG that consists of nodes and values. Each value may have a single
// A DAG that consists of nodes and values. Each value may have a single
// producer node and multiple consumer nodes. Therefore, each node may have
// multiple input and output values.
//
// Value that does not have a producer is a graph's input. Value that does not
// have a consumer is a graph's output.
//
// Interface provides methods for graph introspection and manipulation. An
// abstract interface allows subgraph representations to ensure safe manipulations.
template <typename TensorT>
class Graph {
// It keeps values and nodes referenced by their index in a vector. Therefore,
// nodes and values are never deleted, but rather erased, where corresponding
// index remains.
//
// It is possible to re-use removed indices, but it is not implemented yet.
class GraphFloat32 {
public:
virtual ~Graph() = default;
// @return a collection of nodes in this graph.
virtual std::vector<Node*> nodes() const = 0;
std::vector<Node*> nodes() const;
// @return a collection of values in this graph.
virtual std::vector<Value<TensorT>*> values() const = 0;
std::vector<Value*> values() const;
// @return graph inputs, that are values without producers.
virtual std::vector<Value<TensorT>*> inputs() const = 0;
std::vector<Value*> inputs() const;
// @return graph outputs, that are values without consumers.
virtual std::vector<Value<TensorT>*> outputs() const = 0;
std::vector<Value*> outputs() const;
// @return inputs into the given node. Returns empty vector for deleted node.
virtual std::vector<Value<TensorT>*> FindInputs(NodeId id) const = 0;
std::vector<Value*> FindInputs(NodeId id) const;
// @return outputs from the given node. Returns empty vector for deleted node.
virtual std::vector<Value<TensorT>*> FindOutputs(NodeId id) const = 0;
std::vector<Value*> FindOutputs(NodeId id) const;
virtual bool IsGraphInput(ValueId id) const = 0;
bool IsGraphInput(ValueId id) const;
virtual bool IsGraphOutput(ValueId id) const = 0;
bool IsGraphOutput(ValueId id) const;
// @return producer of the given value. Returns nullptr for deleted value.
virtual Node* FindProducer(ValueId id) const = 0;
Node* FindProducer(ValueId id) const;
// @return consumers of the given value. Returns empty vector for deleted
// value.
virtual std::vector<Node*> FindConsumers(ValueId id) const = 0;
std::vector<Node*> FindConsumers(ValueId id) const;
// @return a node or nullptr if node with the given id is not present.
virtual Node* GetNode(NodeId id) const = 0;
Node* GetNode(NodeId id) const;
// @return a value or nullptr if value with the given id is not present.
virtual Value<TensorT>* GetValue(ValueId id) const = 0;
Value* GetValue(ValueId id) const;
//////////////////////////////////////////////////////////////////////////////
// Graph manipulation functions are below
@@ -129,386 +122,61 @@ class Graph {
// @return new node created in this graph
// NOTE: nodes should be created in the topological order, e.g. node A that
// depends on a value from node B should be created after node B.
virtual Node* NewNode() = 0;
Node* NewNode();
// Insert Node after another in the execution plan.
absl::Status InsertNodeAfter(NodeId id, Node** new_node);
// @return new value created in this graph
virtual Value<TensorT>* NewValue() = 0;
Value* NewValue();
// Sets a producer for the given value. A value may have only a single
// producer. If the value had another producer, the producer is reassigned
// appropriately. If the value didn't have a producer, it is removed
// from the graph's inputs.
virtual absl::Status SetProducer(NodeId producer, ValueId value) = 0;
absl::Status SetProducer(NodeId producer, ValueId value);
// Removes a producer for the given value. Value becomes producer-less and
// therefore becomes graph's input.
virtual absl::Status RemoveProducer(ValueId value) = 0;
absl::Status RemoveProducer(ValueId value);
// Sets a consumer for the given value. There could be multiple consumers
// for a value.
virtual absl::Status AddConsumer(NodeId consumer, ValueId value) = 0;
absl::Status AddConsumer(NodeId consumer, ValueId value);
// Replace input value for given node.
virtual absl::Status ReplaceInput(NodeId node, ValueId old_value,
ValueId new_value) = 0;
absl::Status ReplaceInput(NodeId node, ValueId old_value, ValueId new_value);
// Removes a consumer for the given value. If value does not have any
// consumers it becomes graph's output.
virtual absl::Status RemoveConsumer(NodeId consumer, ValueId value) = 0;
absl::Status RemoveConsumer(NodeId consumer, ValueId value);
// Removes node from this graph. For all input values this node will be
// removed from consumers and for all output values a producer will be
// removed.
virtual absl::Status DeleteNode(NodeId id) = 0;
absl::Status DeleteNode(NodeId id);
// Removes value from this graph. It will be removed from inputs for all
// dependent nodes. A node that was a producer of this value will lose its
// output.
virtual absl::Status DeleteValue(ValueId id) = 0;
};
absl::Status DeleteValue(ValueId id);
// Implementation of a Graph interface. It keeps values and nodes referenced by
// their index in a vector. Therefore, nodes and values are never deleted, but
// rather erased, where corresponding index remains.
//
// It is possible to re-use removed indices, but it is not implemented yet.
template <typename TensorT>
class Model : public Graph<TensorT> {
public:
const std::string& name() const { return name_; }
void set_name(std::string name) { name_ = std::move(name); }
std::vector<Value<TensorT>*> values() const final {
return FilterValues([](const ValueDef&) { return true; });
}
// Returns nodes in the execution order.
std::vector<Node*> nodes() const final {
return FilterNodes([](const NodeDef&) { return true; });
}
std::vector<Value<TensorT>*> inputs() const final {
return FilterValues(
[](const ValueDef& v) { return v.producer == nullptr; });
}
std::vector<Value<TensorT>*> outputs() const final {
return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
}
bool IsGraphInput(ValueId id) const final {
if (id >= values_.size()) {
return false;
}
return values_[id].producer == nullptr;
}
bool IsGraphOutput(ValueId id) const final {
if (id >= values_.size()) {
return false;
}
return values_[id].consumers.empty();
}
Node* GetNode(NodeId id) const final {
if (id >= nodes_.size()) {
return {};
}
return nodes_.at(id).node.get();
}
Value<TensorT>* GetValue(ValueId id) const final {
if (id >= values_.size()) {
return nullptr;
}
return values_[id].value.get();
}
// Append Node to the end of the execution plan.
Node* NewNode() final {
const NodeId new_id = nodes_.size();
NodeDef def;
def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
Node* node = def.node.get();
nodes_[new_id] = std::move(def);
execution_plan_.push_back(new_id);
return node;
}
// Insert Node after another in the execution plan.
absl::Status InsertNodeAfter(NodeId id, Node** new_node) {
if (id >= nodes_.size()) {
return absl::OutOfRangeError("NodeId is out of range");
}
int idx = 0;
while (idx < execution_plan_.size()) {
if (execution_plan_[idx] == id) break;
++idx;
}
if (idx == execution_plan_.size()) {
return absl::OutOfRangeError("NodeId not in execution plan");
}
const NodeId new_id = nodes_.size();
NodeDef def;
def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
*new_node = def.node.get();
nodes_[new_id] = std::move(def);
execution_plan_.insert(execution_plan_.begin() + idx + 1, new_id);
return absl::OkStatus();
}
Value<TensorT>* NewValue() final {
ValueDef def;
def.value = absl::make_unique<Value<TensorT>>(
Value<TensorT>{static_cast<ValueId>(values_.size()), {}});
Value<TensorT>* value = def.value.get();
values_.push_back(std::move(def));
return value;
}
std::vector<Value<TensorT>*> FindInputs(NodeId id) const final {
if (id >= nodes_.size()) {
return {};
}
return nodes_.at(id).inputs;
}
std::vector<Value<TensorT>*> FindOutputs(NodeId id) const final {
if (id >= nodes_.size()) {
return {};
}
return nodes_.at(id).outputs;
}
Node* FindProducer(ValueId id) const final {
if (id >= values_.size()) {
return nullptr;
}
return values_[id].producer;
}
std::vector<Node*> FindConsumers(ValueId id) const final {
if (id >= values_.size()) {
return {};
}
return values_[id].consumers;
}
absl::Status SetProducer(NodeId producer, ValueId value) final {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(value, &v));
Value<TensorT>* value_ptr = v->value.get();
NodeDef* n;
RETURN_IF_ERROR(LookupNode(producer, &n));
Node* node_ptr = n->node.get();
// check if this value has the same producer already
if (node_ptr == v->producer) {
return absl::AlreadyExistsError(absl::StrCat(
"Node ", producer, " is already a producer of the value ", value));
}
// Check if the node is a consumer of this value.
if (IsInput(producer, value)) {
return absl::InvalidArgumentError("Node is a consumer of the value");
}
// TODO(akulik): detect circular dependency?
if (v->producer != nullptr) {
// value is no longer produced by its previous producer.
Erase(&nodes_[v->producer->id].outputs, value_ptr);
}
v->producer = node_ptr;
n->outputs.push_back(value_ptr);
return absl::OkStatus();
}
absl::Status RemoveProducer(ValueId value) final {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(value, &v));
Value<TensorT>* value_ptr = v->value.get();
if (v->producer == nullptr) {
return absl::InvalidArgumentError("Value does not have a producer");
}
Erase(&nodes_[v->producer->id].outputs, value_ptr);
v->producer = nullptr;
return absl::OkStatus();
}
absl::Status ReplaceInput(NodeId node, ValueId old_value,
ValueId new_value) final {
ValueDef* v_old;
RETURN_IF_ERROR(LookupValue(old_value, &v_old));
Value<TensorT>* value_old_ptr = v_old->value.get();
ValueDef* v_new;
RETURN_IF_ERROR(LookupValue(new_value, &v_new));
Value<TensorT>* value_new_ptr = v_new->value.get();
NodeDef* n;
RETURN_IF_ERROR(LookupNode(node, &n));
Node* node_ptr = n->node.get();
// Check if the node is a consumer of old_value.
if (!IsInput(node, old_value)) {
return absl::InvalidArgumentError("old_value must be input of node.");
}
// Check if the node is not a consumer of new_value.
if (IsInput(node, new_value)) {
return absl::InvalidArgumentError("new_value can not be input of node.");
}
// Check if this value has the same producer already
if (node_ptr == v_new->producer) {
return absl::InvalidArgumentError("new_value can not be output of node.");
}
for (int i = 0; i < n->inputs.size(); ++i) {
if (n->inputs[i] == value_old_ptr) {
n->inputs[i] = value_new_ptr;
break;
}
}
v_new->consumers.push_back(node_ptr);
Erase(&v_old->consumers, node_ptr);
return absl::OkStatus();
}
absl::Status AddConsumer(NodeId consumer, ValueId value) final {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(value, &v));
Value<TensorT>* value_ptr = v->value.get();
NodeDef* n;
RETURN_IF_ERROR(LookupNode(consumer, &n));
Node* node_ptr = n->node.get();
// Check if the node is already a producer of the value.
if (node_ptr == v->producer) {
return absl::InvalidArgumentError("Node is a producer of the value");
}
// check if this value has the same consumer already
if (IsInput(consumer, value)) {
return absl::AlreadyExistsError(absl::StrCat(
"Node ", consumer, " is already a consumer of the value ", value));
}
n->inputs.push_back(value_ptr);
v->consumers.push_back(node_ptr);
return absl::OkStatus();
}
absl::Status RemoveConsumer(NodeId consumer, ValueId value) final {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(value, &v));
Value<TensorT>* value_ptr = v->value.get();
NodeDef* n;
RETURN_IF_ERROR(LookupNode(consumer, &n));
Node* node_ptr = n->node.get();
if (!IsInput(consumer, value)) {
return absl::InvalidArgumentError("Node is not a consumer of the value");
}
Erase(&n->inputs, value_ptr);
Erase(&v->consumers, node_ptr);
return absl::OkStatus();
}
absl::Status DeleteNode(NodeId id) final {
NodeDef* n;
RETURN_IF_ERROR(LookupNode(id, &n));
Node* node_ptr = n->node.get();
for (auto value : n->inputs) {
Erase(&values_[value->id].consumers, node_ptr);
}
for (auto value : n->outputs) {
values_[value->id].producer = nullptr;
}
n->inputs.clear();
n->outputs.clear();
n->node.reset();
return absl::OkStatus();
}
absl::Status DeleteValue(ValueId id) final {
ValueDef* v;
RETURN_IF_ERROR(LookupValue(id, &v));
Value<TensorT>* value_ptr = v->value.get();
if (v->producer != nullptr) {
Erase(&nodes_[v->producer->id].outputs, value_ptr);
}
if (!v->consumers.empty()) {
for (auto node : v->consumers) {
Erase(&nodes_[node->id].inputs, value_ptr);
}
}
v->producer = nullptr;
v->consumers.clear();
v->value.reset();
return absl::OkStatus();
}
absl::Status MakeExactCopy(Model<TensorT>* model) const {
model->nodes_.clear();
model->execution_plan_.clear();
model->values_.clear();
model->name_ = name_;
for (auto& value_def : values_) {
model->values_.push_back({});
if (value_def.value) {
model->values_.back().value =
absl::make_unique<Value<TensorT>>(*value_def.value);
}
}
// Add all nodes first.
for (auto node_id : execution_plan_) {
model->execution_plan_.push_back(node_id);
model->nodes_[node_id] = {};
auto& node_def = nodes_.at(node_id);
if (node_def.node) {
model->nodes_[node_id].node = absl::make_unique<Node>(*node_def.node);
}
}
// Wire up dependencies between nodes.
for (auto node_id : execution_plan_) {
auto& node_def = nodes_.at(node_id);
if (node_def.node) {
for (auto output : node_def.outputs) {
RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id));
}
for (auto input : node_def.inputs) {
RETURN_IF_ERROR(model->AddConsumer(node_def.node->id, input->id));
}
}
}
return absl::OkStatus();
}
absl::Status MakeExactCopy(GraphFloat32* model) const;
private:
struct NodeDef {
std::vector<Value<TensorT>*> inputs;
std::vector<Value<TensorT>*> outputs;
std::vector<Value*> inputs;
std::vector<Value*> outputs;
std::unique_ptr<Node> node;
};
struct ValueDef {
Node* producer = nullptr;
std::vector<Node*> consumers;
std::unique_ptr<Value<TensorT>> value;
std::unique_ptr<Value> value;
};
bool IsInput(NodeId node, ValueId value) {
if (node >= nodes_.size() || value >= values_.size()) {
return false;
}
const NodeDef& n = nodes_[node];
const ValueDef& v = values_[value];
if (!n.node || !v.value) {
return false;
}
return std::find(n.inputs.begin(), n.inputs.end(), v.value.get()) !=
n.inputs.end();
}
bool IsInput(NodeId node, ValueId value);
template <typename T>
static void Erase(std::vector<T>* values, T value) {
@@ -516,34 +184,14 @@ class Model : public Graph<TensorT> {
}
// @return non-nullptr NodeDef that has valid Node or an error
absl::Status LookupNode(NodeId id, NodeDef** node_def) {
if (id >= nodes_.size()) {
return absl::OutOfRangeError("NodeId is out of range");
}
auto& n = nodes_[id];
if (!n.node) {
return absl::OutOfRangeError("Node is already deleted");
}
*node_def = &n;
return absl::OkStatus();
}
absl::Status LookupNode(NodeId id, NodeDef** node_def);
// @return non-nullptr ValueDef that has valid Value or an error
absl::Status LookupValue(ValueId id, ValueDef** value_def) {
if (id >= values_.size()) {
return absl::OutOfRangeError("ValueId is out of range");
}
auto& v = values_[id];
if (!v.value) {
return absl::OutOfRangeError("Value is already deleted");
}
*value_def = &v;
return absl::OkStatus();
}
absl::Status LookupValue(ValueId id, ValueDef** value_def);
template <typename Pred>
std::vector<Value<TensorT>*> FilterValues(const Pred& predicate) const {
std::vector<Value<TensorT>*> values;
std::vector<Value*> FilterValues(const Pred& predicate) const {
std::vector<Value*> values;
values.reserve(values_.size());
for (auto& v : values_) {
if (v.value != nullptr && predicate(v)) {
@@ -566,8 +214,6 @@ class Model : public Graph<TensorT> {
return nodes;
}
std::string name_;
// There are two approaches possible: wrap entire NodeDef and ValueDef into
// unique_ptr and store it in values_ and nodes_ or store it by value.
// We store it by value here to make introspection calls cheaper.
@@ -581,108 +227,27 @@ class Model : public Graph<TensorT> {
// Removes to_remove node that precedes to_keep node only if to_remove has
// outputs that are consumed only by to_keep. In such case to_keep inherits all
// to_remove inputs.
template <typename TensorT>
absl::Status RemovePrecedingNode(Graph<TensorT>* graph, const Node* to_remove,
const Node* to_keep) {
// Make sure all outputs from to_remove are consumed by to_keep.
for (auto output : graph->FindOutputs(to_remove->id)) {
auto consumers = graph->FindConsumers(output->id);
if (consumers.size() > 1 ||
(consumers.size() == 1 && consumers[0] != to_keep)) {
return absl::InvalidArgumentError(
"Output from to_remove node has other consumers");
}
}
// Update all references
for (auto input : graph->FindInputs(to_remove->id)) {
RETURN_IF_ERROR(graph->AddConsumer(to_keep->id, input->id));
}
for (auto output : graph->FindOutputs(to_remove->id)) {
RETURN_IF_ERROR(graph->DeleteValue(output->id));
}
return graph->DeleteNode(to_remove->id);
}
absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove,
const Node* to_keep);
// Removes to_remove node that follows to_keep node only if to_remove has inputs
// that are produced by to_keep. to_keep inherits all to_remove outputs.
template <typename TensorT>
absl::Status RemoveFollowingNode(Graph<TensorT>* graph, const Node* to_remove,
const Node* to_keep) {
// Make sure all inputs to to_remove are produced by to_keep.
for (auto input : graph->FindInputs(to_remove->id)) {
Node* producer = graph->FindProducer(input->id);
if (producer->id != to_keep->id) {
return absl::InvalidArgumentError("To_remove node has other inputs");
}
}
for (auto input : graph->FindInputs(to_remove->id)) {
RETURN_IF_ERROR(graph->DeleteValue(input->id));
}
for (auto output : graph->FindOutputs(to_remove->id)) {
RETURN_IF_ERROR(graph->SetProducer(to_keep->id, output->id));
}
return graph->DeleteNode(to_remove->id);
}
absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove,
const Node* to_keep);
// Removes to_remove node.
// Requires that the node has one input and one output.
template <typename TensorT>
absl::Status RemoveOneInputOneOutputNode(Graph<TensorT>* graph,
const Node* to_remove) {
auto inputs = graph->FindInputs(to_remove->id);
auto outputs = graph->FindOutputs(to_remove->id);
if (inputs.size() != 1 || outputs.size() != 1) {
return absl::InvalidArgumentError(
"To_remove node must have 1 input and 1 output");
}
auto input_id = inputs[0]->id;
auto output_id = outputs[0]->id;
Node* producer = graph->FindProducer(input_id);
auto consumers = graph->FindConsumers(output_id);
RETURN_IF_ERROR(graph->DeleteNode(to_remove->id));
for (auto& consumer : consumers) {
RETURN_IF_ERROR(graph->ReplaceInput(consumer->id, output_id, input_id));
}
RETURN_IF_ERROR(graph->DeleteValue(output_id));
if (!producer && consumers.empty()) {
RETURN_IF_ERROR(graph->DeleteValue(input_id));
}
return absl::OkStatus();
}
absl::Status RemoveOneInputOneOutputNode(GraphFloat32* graph,
const Node* to_remove);
template <typename TensorT>
absl::Status AddOutput(Graph<TensorT>* graph, const Node* from_node,
Value<TensorT>** output) {
auto link = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id));
*output = link;
return absl::OkStatus();
}
absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
Value** output);
template <typename TensorT>
absl::Status ConnectTwoNodes(Graph<TensorT>* graph, const Node* from_node,
const Node* to_node, Value<TensorT>** output) {
Value<TensorT>* link;
RETURN_IF_ERROR(AddOutput(graph, from_node, &link));
RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id));
*output = link;
return absl::OkStatus();
}
using GraphFloat32 = Model<TensorRef<BHWC>>;
absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
const Node* to_node, Value** output);
// @return true if all tensors have same batch value.
inline bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
const int32_t b = model.values()[0]->tensor.shape.b;
for (auto value : model.values()) {
if (value->tensor.shape.b != b) {
return false;
}
}
return true;
}
bool IsBatchMatchesForAllValues(const GraphFloat32& model);
} // namespace gpu
} // namespace tflite
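A sketch of how the free helpers compose, mirroring the RemoveSimpleNode tests further down (the function name SpliceOut is hypothetical):

absl::Status SpliceOut(GraphFloat32* graph) {
  // Build producer -> simple -> consumer, where `simple` is a pass-through.
  Node* producer = graph->NewNode();
  Node* simple = graph->NewNode();
  Node* consumer = graph->NewNode();
  Value* graph_input = graph->NewValue();
  RETURN_IF_ERROR(graph->AddConsumer(producer->id, graph_input->id));
  Value* link1;
  RETURN_IF_ERROR(ConnectTwoNodes(graph, producer, simple, &link1));
  Value* link2;
  RETURN_IF_ERROR(ConnectTwoNodes(graph, simple, consumer, &link2));
  Value* graph_output;
  RETURN_IF_ERROR(AddOutput(graph, consumer, &graph_output));
  // `simple` has exactly one input and one output, so it can be spliced out;
  // afterwards `consumer` consumes link1 directly and link2 is deleted.
  return RemoveOneInputOneOutputNode(graph, simple);
}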


@@ -66,12 +66,11 @@ namespace {
// will turn into:
// node(copy(output)) <- passthrough_node(output)
absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node,
const Value<TensorRef<BHWC>>* output,
Node** passthru_node) {
const Value* output, Node** passthru_node) {
*passthru_node = graph->NewNode();
// Make copies for every output in the original node.
RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
Value<TensorRef<BHWC>>* copy_output = graph->NewValue();
Value* copy_output = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
copy_output->tensor = output->tensor;
@@ -323,8 +322,7 @@ absl::Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h,
}
// Creates a simple node that holds tensor value.
absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
Value<TensorRef<BHWC>>** value) {
absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
ConstTensorAttributes attr;
attr.tensor = std::move(t);
Node* node = graph->NewNode();
@@ -457,16 +455,16 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
ConcatAttributes attr;
// Read inputs first to make sure const node is added to a graph before
// concat node to ensure topological order.
std::vector<const Value<TensorRef<BHWC>>*> inputs;
std::vector<const Value*> inputs;
for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
Value<TensorRef<BHWC>>* value;
Value* value;
const auto status = reader->ReadValue(idx, &value);
if (status.ok()) {
inputs.push_back(value);
} else {
TensorFloat32 tensor;
RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
Value<TensorRef<BHWC>>* value;
Value* value;
RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
inputs.push_back(value);
}
@@ -475,7 +473,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONCAT);
RETURN_IF_ERROR(reader->AddOutputs(node));
for (const Value<TensorRef<BHWC>>* input : inputs) {
for (const Value* input : inputs) {
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
}
@@ -1011,7 +1009,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
auto& reshape = node;
conv = graph->NewNode(); // reset conv pointer!
Value<TensorRef<BHWC>>* reshaped_value = graph->NewValue();
Value* reshaped_value = graph->NewValue();
reshaped_value->tensor.type = DataType::FLOAT32;
reshaped_value->tensor.shape =
BHWC(input->tensor.shape.b, 1, 1, weights.shape.w);
@@ -1121,11 +1119,11 @@ class LSTMOperationParser : public TFLiteOperationParser {
lstm_attr.kernel_type = LstmKernelType::BASIC;
lstm_node->operation.attributes = lstm_attr;
Value<TensorRef<BHWC>>* concat_temp;
Value* concat_temp;
int concat_tensor_idx = tflite_node->outputs->data[2];
RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
Value<TensorRef<BHWC>>* activ_temp;
Value* activ_temp;
int activ_tensor_idx = tflite_node->outputs->data[3];
RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
@@ -1677,7 +1675,7 @@ class SliceOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SLICE);
RETURN_IF_ERROR(reader->AddOutputs(node));
Value<TensorRef<BHWC>>* input;
Value* input;
RETURN_IF_ERROR(reader->ReadValue(0, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
@@ -1842,7 +1840,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SLICE);
RETURN_IF_ERROR(reader->AddOutputs(node));
Value<TensorRef<BHWC>>* input;
Value* input;
RETURN_IF_ERROR(reader->ReadValue(0, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
@@ -2027,7 +2025,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
GraphFloat32* graph, ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
Value<TensorRef<BHWC>>* input;
Value* input;
RETURN_IF_ERROR(reader->ReadValue(2, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
RETURN_IF_ERROR(reader->AddOutputs(node));
@@ -2677,7 +2675,7 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops) {
absl::Status PrecreateIOTensors(
TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors,
std::unordered_map<int, int>* quant_conversion_map,
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value) {
std::unordered_map<int, Value*>* tensor_to_value) {
for (int i = 0; i < io_tensors->size; ++i) {
const int tensor_index = io_tensors->data[i];
const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
@@ -2717,7 +2715,7 @@ absl::Status BuildModel(TfLiteContext* context,
operations.push_back(std::move(op_parser));
tflite_nodes.push_back(i);
}
std::unordered_map<int, Value<TensorRef<BHWC>>*> tensor_to_value;
std::unordered_map<int, Value*> tensor_to_value;
RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
delegate_params->input_tensors,
quant_conversion_map, &tensor_to_value));


@@ -32,8 +32,8 @@ TEST(Model, SingleNode) {
// graph_input -> node -> graph_output
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
@@ -53,9 +53,9 @@ TEST(Model, SingleNodeMultipleOutputs) {
// graph_input -> node -> (graph_output1, graph_output2)
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output1 = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output1 = graph.NewValue();
Value* graph_output2 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output1->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output2->id).ok());
@@ -68,7 +68,7 @@ TEST(Model, SingleNodeMultipleOutputs) {
TEST(Model, SetSameConsumer) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value* graph_input = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
EXPECT_FALSE(graph.AddConsumer(node->id, graph_input->id).ok());
}
@@ -77,8 +77,8 @@ TEST(Model, RemoveConsumer) {
// (graph_input1, graph_input2) -> node
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input1 = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input2 = graph.NewValue();
Value* graph_input1 = graph.NewValue();
Value* graph_input2 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input1->id).ok());
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input2->id).ok());
EXPECT_THAT(graph.FindConsumers(graph_input1->id),
@@ -102,7 +102,7 @@ TEST(Model, RemoveConsumer) {
TEST(Model, SetSameProducer) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value* graph_output = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
EXPECT_FALSE(graph.SetProducer(node->id, graph_output->id).ok());
}
@@ -110,10 +110,10 @@ TEST(Model, SetSameProducer) {
TEST(Model, ReplaceInput) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* v0 = graph.NewValue();
Value<TensorRef<BHWC>>* v1 = graph.NewValue();
Value<TensorRef<BHWC>>* v2 = graph.NewValue();
Value<TensorRef<BHWC>>* v3 = graph.NewValue();
Value* v0 = graph.NewValue();
Value* v1 = graph.NewValue();
Value* v2 = graph.NewValue();
Value* v3 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, v0->id).ok());
ASSERT_TRUE(graph.AddConsumer(node->id, v1->id).ok());
ASSERT_TRUE(graph.AddConsumer(node->id, v2->id).ok());
@@ -125,7 +125,7 @@ TEST(Model, ReplaceInput) {
TEST(Model, RemoveProducer) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value* graph_output = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
@@ -142,8 +142,8 @@ TEST(Model, RemoveProducer) {
TEST(Model, RemoveSimpleNodeDegenerateCase) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
@@ -161,9 +161,9 @@ TEST(Model, RemoveSimpleNodeNoPreviousNode) {
GraphFloat32 graph;
Node* simple_node = graph.NewNode();
Node* consumer_node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
Value* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(simple_node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(simple_node->id, value->id).ok());
@@ -183,9 +183,9 @@ TEST(Model, RemoveSimpleNodeNoAfterNodes) {
GraphFloat32 graph;
Node* simple_node = graph.NewNode();
Node* producer_node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
Value* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(simple_node->id, value->id).ok());
ASSERT_TRUE(graph.SetProducer(simple_node->id, graph_output->id).ok());
@@ -206,10 +206,10 @@ TEST(Model, RemoveSimpleNodeGeneralCase) {
Node* simple_node = graph.NewNode();
Node* producer_node = graph.NewNode();
Node* consumer_node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value0 = graph.NewValue();
Value<TensorRef<BHWC>>* value1 = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
Value* value0 = graph.NewValue();
Value* value1 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(producer_node->id, value0->id).ok());
@@ -257,12 +257,12 @@ TEST(Model, RemoveSimpleNodeComplexCase) {
Node* n0 = graph.NewNode();
Node* n1 = graph.NewNode(); // node to remove
Node* n2 = graph.NewNode();
Value<TensorRef<BHWC>>* v0 = graph.NewValue();
Value<TensorRef<BHWC>>* v1 = graph.NewValue();
Value<TensorRef<BHWC>>* v2 = graph.NewValue(); // value to be removed
Value<TensorRef<BHWC>>* v3 = graph.NewValue();
Value<TensorRef<BHWC>>* o1 = graph.NewValue();
Value<TensorRef<BHWC>>* o2 = graph.NewValue();
Value* v0 = graph.NewValue();
Value* v1 = graph.NewValue();
Value* v2 = graph.NewValue(); // value to be removed
Value* v3 = graph.NewValue();
Value* o1 = graph.NewValue();
Value* o2 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(n0->id, v0->id).ok());
ASSERT_TRUE(graph.AddConsumer(n0->id, v1->id).ok());
@@ -289,14 +289,14 @@ TEST(Model, CircularDependency) {
{
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* value = graph.NewValue();
Value* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, value->id).ok());
EXPECT_FALSE(graph.SetProducer(node->id, value->id).ok());
}
{
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRef<BHWC>>* value = graph.NewValue();
Value* value = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, value->id).ok());
EXPECT_FALSE(graph.AddConsumer(node->id, value->id).ok());
}
@@ -309,8 +309,8 @@ TEST(Model, ReassignValue) {
GraphFloat32 graph;
Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, graph_output->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, graph_input->id).ok());
@@ -336,9 +336,9 @@ TEST(Model, DeleteValue) {
GraphFloat32 graph;
Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
Value* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
@@ -377,10 +377,10 @@ TEST(Model, DeleteNode) {
Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode();
Node* node3 = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
Value* graph_output2 = graph.NewValue();
Value* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
@@ -437,9 +437,9 @@ TEST(Model, InsertNodeAfter) {
GraphFloat32 graph;
Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
Value* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());


@@ -28,10 +28,9 @@ namespace tflite {
namespace gpu {
absl::Status ObjectReader::ReadNonConstantTensor(
TfLiteContext* context,
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
TfLiteContext* context, std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
uint32_t tensor_idx, Value<TensorRef<BHWC>>** value) {
uint32_t tensor_idx, Value** value) {
if (tensor_idx >= context->tensors_size) {
return absl::OutOfRangeError(
absl::StrCat("ReadNonConstTensor: input tensor index: ", tensor_idx));
@@ -63,7 +62,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
(*quant_conversion_map)[fp_tensor_index] = tensor_idx;
(*quant_conversion_map)[tensor_idx] = fp_tensor_index;
// Add a new GPU Value for the new dequantized floating-point tensor.
Value<TensorRef<BHWC>>* value = graph->NewValue();
Value* value = graph->NewValue();
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
value->tensor.ref = fp_tensor_index;
@@ -77,7 +76,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
tensor_idx = quant_conversion_map->at(tensor_idx);
} else {
// Floating-point case.
Value<TensorRef<BHWC>>* value = graph->NewValue();
Value* value = graph->NewValue();
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
value->tensor.ref = tensor_idx;
@@ -91,8 +90,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
return absl::OkStatus();
}
absl::Status ObjectReader::ReadValue(uint32_t idx,
Value<TensorRef<BHWC>>** value) {
absl::Status ObjectReader::ReadValue(uint32_t idx, Value** value) {
if (idx >= node_->inputs->size) {
return absl::OutOfRangeError(
absl::StrCat("ReadValue: input tensor index: ", idx));
@@ -100,8 +98,8 @@ absl::Status ObjectReader::ReadValue(uint32_t idx,
return ReadValueByTensorIdx(node_->inputs->data[idx], value);
}
absl::Status ObjectReader::ReadValueByTensorIdx(
uint32_t tensor_idx, Value<TensorRef<BHWC>>** value) {
absl::Status ObjectReader::ReadValueByTensorIdx(uint32_t tensor_idx,
Value** value) {
// Constant tensors should be handled by ReadTensor.
return ReadNonConstantTensor(context_, tensor_to_value_,
quant_conversion_map_, graph_, tensor_idx,
@@ -133,7 +131,7 @@ absl::Status ObjectReader::AddOutput(const Node* node, int id) {
node_->outputs->size));
}
int output_tensor_idx = node_->outputs->data[id];
Value<TensorRef<BHWC>>* value;
Value* value;
RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
return absl::OkStatus();
@@ -147,7 +145,7 @@ absl::Status ObjectReader::AddOutputs(const Node* node) {
}
absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
Value<TensorRef<BHWC>>* input;
Value* input;
RETURN_IF_ERROR(ReadValue(idx, &input));
return graph_->AddConsumer(node->id, input->id);
}


@@ -34,25 +34,23 @@ namespace gpu {
class ObjectReader {
public:
static absl::Status ReadNonConstantTensor(
TfLiteContext* context,
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
TfLiteContext* context, std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
uint32_t tensor_idx, Value<TensorRef<BHWC>>** value = nullptr);
uint32_t tensor_idx, Value** value = nullptr);
ObjectReader(
GraphFloat32* graph, TfLiteContext* context, const TfLiteNode* node,
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map = nullptr)
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
const TfLiteNode* node,
std::unordered_map<int, Value*>* tensor_to_value,
std::unordered_map<int, int>* quant_conversion_map = nullptr)
: graph_(graph),
context_(context),
node_(node),
tensor_to_value_(tensor_to_value),
quant_conversion_map_(quant_conversion_map) {}
absl::Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value);
absl::Status ReadValue(uint32_t idx, Value** value);
absl::Status ReadValueByTensorIdx(uint32_t tensor_idx,
Value<TensorRef<BHWC>>** value);
absl::Status ReadValueByTensorIdx(uint32_t tensor_idx, Value** value);
int GetNumberOfRuntimeInputs() const;
@ -89,7 +87,7 @@ class ObjectReader {
GraphFloat32* graph_;
TfLiteContext* context_;
const TfLiteNode* node_;
std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value_;
std::unordered_map<int, Value*>* tensor_to_value_;
std::unordered_map<int, int>* quant_conversion_map_;
};
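
Every signature above sheds the Value<TensorRef<BHWC>> template argument. A hedged sketch of what the resulting concrete class looks like, with stand-in field types (the real definition lives in common/model.h and uses TensorRef<BHWC> plus an optional QuantizationParams):

#include <cstdint>

// Stand-in types for illustration only; not the real definitions.
struct BHWC { int b = 0, h = 0, w = 0, c = 0; };
struct TensorRefBHWC { BHWC shape; int64_t ref = -1; };
struct QuantizationParams { float min = 0, max = 0, scale = 0; };

// Sketch of the de-templatized Value: one concrete class, so callers such
// as ObjectReader no longer spell out the tensor type at every use site.
struct Value {
  uint32_t id = 0;
  TensorRefBHWC tensor;
  QuantizationParams quant_params;  // optional<> in the real class
};

int main() {
  Value v;
  v.tensor.shape = {1, 4, 4, 8};
  return static_cast<int>(v.id);
}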

View File

@ -64,7 +64,7 @@ class AddQuantAdjustments : public NodeTransformation {
// Add one output Value for the new node.
// The tensor information should remain the same.
Value<TensorRef<BHWC>>* adjusted_value = graph->NewValue();
Value* adjusted_value = graph->NewValue();
adjusted_value->tensor = output_value->tensor;
status =
graph->SetProducer(quant_and_dequant_node->id, adjusted_value->id);
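
The pattern in this hunk — clone the tensor info onto a fresh Value, then make the inserted node its producer — is how these transformations splice a node into the graph. A toy sketch of just that splice, with simplified stand-ins for GraphFloat32's bookkeeping:

#include <cstdint>
#include <map>
#include <memory>
#include <vector>

// Toy stand-ins; the real Node/Value/GraphFloat32 live in common/model.h.
struct Tensor { std::vector<int> shape; };
struct Value { uint32_t id = 0; Tensor tensor; };

class ToyGraph {
 public:
  Value* NewValue() {
    values_.push_back(std::make_unique<Value>());
    values_.back()->id = next_id_++;
    return values_.back().get();
  }
  // Records node_id as the producer of value_id.
  void SetProducer(uint32_t node_id, uint32_t value_id) {
    producer_[value_id] = node_id;
  }

 private:
  std::vector<std::unique_ptr<Value>> values_;
  std::map<uint32_t, uint32_t> producer_;
  uint32_t next_id_ = 0;
};

int main() {
  ToyGraph graph;
  Value* output_value = graph.NewValue();
  output_value->tensor.shape = {1, 4, 4, 8};
  // As in AddQuantAdjustments: the new Value carries the same tensor
  // information, and the inserted node becomes its producer.
  Value* adjusted_value = graph.NewValue();
  adjusted_value->tensor = output_value->tensor;
  const uint32_t quant_and_dequant_node_id = 0;  // assumed node id
  graph.SetProducer(quant_and_dequant_node_id, adjusted_value->id);
  return 0;
}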

View File

@ -59,7 +59,7 @@ TEST(AddQuantAdjustments, OneNode) {
ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());
Value<TensorRef<BHWC>>* output;
Value* output;
AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/2.0,
/*scale=*/0.008);
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
@ -114,18 +114,18 @@ TEST(AddQuantAdjustments, GeneralCase) {
// Connections.
ASSERT_TRUE(graph.AddConsumer(add1_node->id, input->id).ok());
Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, add1_node, quant_node, &link1).ok());
AddQuantParams(&link1->quant_params, /*min=*/0.0, /*max=*/2.0,
/*scale=*/0.008);
link1->tensor.shape = BHWC(1, 4, 4, 8);
ASSERT_TRUE(graph.AddConsumer(add2_node->id, link1->id).ok());
Value<TensorRef<BHWC>>* link2;
Value* link2;
ASSERT_TRUE(ConnectTwoNodes(&graph, quant_node, add2_node, &link2).ok());
AddQuantParams(&link2->quant_params, /*min=*/-1.0, /*max=*/1.0,
/*scale=*/0.008);
link2->tensor.shape = BHWC(1, 4, 4, 8);
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add2_node, &output).ok());
AddQuantParams(&output->quant_params, /*min=*/-1.0, /*max=*/1.0,
/*scale=*/0.008);
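
This test (and the ones below) wires graphs together with the ConnectTwoNodes helper, which is declared elsewhere and not part of this change. A plausible implementation, under the assumption that it simply creates a fresh linking Value between the two nodes — a sketch of its behavior, not code from this commit:

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

// Sketch: make from_node produce a new Value, feed that Value into
// to_node, and hand the link back for shape/quant annotation.
absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
                             const Node* to_node, Value** output) {
  Value* link = graph->NewValue();
  RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id));
  RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id));
  *output = link;
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite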

View File

@ -57,11 +57,11 @@ TEST(MergeConvolutionWithAddTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);
@ -108,11 +108,11 @@ TEST(MergeAddWithConvolutionTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);

View File

@ -58,11 +58,11 @@ TEST(MergeConvolutionWithMulTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);
@ -109,11 +109,11 @@ TEST(MergeMulWithConvolutionTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok());
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);

View File

@ -68,16 +68,16 @@ TEST(MakeFullyConnected, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok());
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok());
output->tensor.shape = BHWC(1, 1, 1, 32);
Value<TensorRef<BHWC>>* link1;
Value* link1;
ASSERT_TRUE(
ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRef<BHWC>>* link2;
Value* link2;
ASSERT_TRUE(
ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok());
link2->tensor.shape = BHWC(1, 1, 1, 16);

View File

@ -38,7 +38,7 @@ TEST(MakePadding, Smoke) {
attr.axis = Axis::HEIGHT;
concat_node->operation.attributes = attr;
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok());
output->tensor.shape = BHWC(1, 7, 3, 5);
@ -50,7 +50,7 @@ TEST(MakePadding, Smoke) {
std::vector<float>(const_attr.tensor.shape.DimensionsProduct(), 0);
const_node->operation.attributes = const_attr;
Value<TensorRef<BHWC>>* const_link;
Value* const_link;
ASSERT_TRUE(
ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok());
const_link->tensor.shape = const_attr.tensor.shape;

View File

@ -40,7 +40,7 @@ TEST(MergePaddingWith, Smoke) {
pad_node->operation.attributes = attr;
auto conv_node = graph.NewNode();
Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok());
ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok());
conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
@ -77,7 +77,7 @@ TEST(MergePaddingWith, MergeTwo) {
pad_node1->operation.attributes = attr;
auto pad_node2 = graph.NewNode();
Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok());
pad_node2->operation.type = ToString(OperationType::PAD);
attr.prepended = BHWC(0, 0, 0, 0);

View File

@ -35,12 +35,12 @@ TEST(RemoveSingleInputAdd, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto add_node = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
add_node->operation.attributes = AddAttributes();
Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size());
@ -63,14 +63,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Tensor) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto add_node = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
AddAttributes attr;
attr.param = Tensor<Linear, DataType::FLOAT32>();
add_node->operation.attributes = attr;
Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size());
@ -90,14 +90,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Scalar) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto add_node = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
AddAttributes attr;
attr.param = 0.5f;
add_node->operation.attributes = attr;
Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size());
@ -119,11 +119,11 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Multiple) {
ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok());
auto add_node = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
Value<TensorRef<BHWC>>* temp;
Value* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok());
ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok());
ASSERT_EQ(3, graph.nodes().size());
@ -144,7 +144,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto node_to_remove = graph.NewNode();
Value<TensorRef<BHWC>>* output;
Value* output;
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
output->tensor.shape = BHWC(1, 5, 5, 1);
node_to_remove->operation.type = ToString(OperationType::RESIZE);
@ -153,7 +153,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
attr.type = SamplingType::BILINEAR;
node_to_remove->operation.attributes = attr;
Value<TensorRef<BHWC>>* link;
Value* link;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
link->tensor.shape = output->tensor.shape;
ASSERT_EQ(2, graph.nodes().size());
@ -175,10 +175,10 @@ TEST(RemoveIdentityReshape, Smoke) {
Node* simple_node = graph.NewNode();
Node* producer_node = graph.NewNode();
Node* consumer_node = graph.NewNode();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value0 = graph.NewValue();
Value<TensorRef<BHWC>>* value1 = graph.NewValue();
Value* graph_input = graph.NewValue();
Value* graph_output = graph.NewValue();
Value* value0 = graph.NewValue();
Value* value1 = graph.NewValue();
value0->tensor.shape = BHWC(1, 1, 1, 11);
simple_node->operation.type = ToString(OperationType::RESHAPE);

View File

@ -580,8 +580,7 @@ class InferenceBuilderImpl : public InferenceBuilder {
private:
// Links internal tensors with external user-facing objects.
std::vector<TensorTieDef> LinkTensors(
const std::vector<Value<TensorRef<BHWC>>*>& values) {
std::vector<TensorTieDef> LinkTensors(const std::vector<Value*>& values) {
std::vector<TensorTieDef> links;
links.reserve(values.size());
for (const auto& value : values) {

View File

@ -145,7 +145,7 @@ class Delegate {
// TODO(impjdi): Remove code duplication.
auto values = graph.values();
auto find_value = [&](int tensor_index) -> Value<TensorRef<BHWC>>* {
auto find_value = [&](int tensor_index) -> Value* {
for (auto value : values) {
if (value->tensor.ref == tensor_index) return value;
}
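
The TODO notes that this lambda is duplicated across the delegates (the same code appears again below). One way a shared helper might look — hypothetical name, sketched against the public GraphFloat32 API:

#include "tensorflow/lite/delegates/gpu/common/model.h"

namespace tflite {
namespace gpu {

// Hypothetical shared lookup: linear scan for the Value whose tensor.ref
// matches the given TfLite tensor index; nullptr if there is none.
Value* FindValueByTensorIndex(const GraphFloat32& graph, int tensor_index) {
  for (Value* value : graph.values()) {
    if (value->tensor.ref == tensor_index) return value;
  }
  return nullptr;
}

}  // namespace gpu
}  // namespace tflite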

View File

@ -54,20 +54,17 @@ namespace gpu {
namespace metal {
namespace {
bool IsWidthBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsWidthBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
inputs[1]->tensor.shape.w == 1;
}
bool IsHeightBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsHeightBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
inputs[1]->tensor.shape.h == 1;
}
bool IsChannelsBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
bool IsChannelsBroadcastedForSecondInput(const std::vector<Value*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
inputs[1]->tensor.shape.c == 1;
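
The three predicates above differ only in which axis they inspect: exactly two inputs, the second has extent 1 on that axis, and the first does not. A self-contained sketch (toy types, not the real ones) of how they could collapse into a single accessor-taking helper:

#include <vector>

// Toy stand-ins for illustration; real types are in common/model.h.
struct BHWC { int b, h, w, c; };
struct TensorRefBHWC { BHWC shape; };
struct Value { TensorRefBHWC tensor; };

// One helper covering width/height/channels: the axis is picked by the
// accessor the caller passes in.
template <typename Accessor>
bool IsAxisBroadcastedForSecondInput(const std::vector<Value*>& inputs,
                                     Accessor axis) {
  return inputs.size() == 2 &&
         axis(inputs[0]->tensor.shape) != axis(inputs[1]->tensor.shape) &&
         axis(inputs[1]->tensor.shape) == 1;
}

int main() {
  Value a{{{1, 4, 4, 8}}};
  Value b{{{1, 4, 1, 8}}};
  std::vector<Value*> inputs = {&a, &b};
  // Width is broadcast for the second input in this example.
  return IsAxisBroadcastedForSecondInput(
             inputs, [](const BHWC& s) { return s.w; })
             ? 0
             : 1;
}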

View File

@ -239,7 +239,7 @@ class Delegate {
// TODO(impjdi): Remove code duplication.
auto values = graph.values();
auto find_value = [&](int tensor_index) -> Value<TensorRef<BHWC>>* {
auto find_value = [&](int tensor_index) -> Value* {
for (auto value : values) {
if (value->tensor.ref == tensor_index) return value;
}