Get rid of the TensorRefFloat32 specialization. TensorRef no longer keeps the data type in its template parameters, so the alias name is misleading; it is better to write TensorRef<BHWC> explicitly.

PiperOrigin-RevId: 253738844
A. Unique TensorFlower authored on 2019-06-18 00:38:53 -07:00; committed by TensorFlower Gardener
parent 3557034977
commit 5044a18831
34 changed files with 207 additions and 207 deletions
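
Note for readers skimming the diff below: this is a minimal compilable sketch of the situation before and after this commit, not the actual TFLite GPU headers; DataType, BHWC, and the exact field set are stand-ins reconstructed from how the diff uses them.

// Sketch only. The layout (e.g. BHWC) is the template parameter; the element
// type is a runtime field, which is why a "Float32" alias no longer fits.
#include <cstdint>

enum class DataType { UNKNOWN, FLOAT32, INT32 };
struct BHWC { int32_t b = 0, h = 0, w = 0, c = 0; };

template <typename ShapeT>
struct TensorRef {
  DataType type = DataType::UNKNOWN;  // element type is data, not a template param
  ShapeT shape;
  int64_t ref = -1;  // opaque reference to a TfLite tensor index
};

// Assumed form of the alias this commit removes:
// using TensorRefFloat32 = TensorRef<BHWC>;

int main() {
  TensorRef<BHWC> tensor_ref;            // the explicit spelling the commit adopts
  tensor_ref.type = DataType::FLOAT32;   // the element type is stated where it matters
  return 0;
}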


@@ -68,12 +68,12 @@ int64_t DimensionsProduct(const TfLiteIntArray& dims) {
// will turn into:
// node(copy(output)) <- passthrough_node(output)
Status NewPassthroughNode(GraphFloat32* graph, Node* node,
const Value<TensorRefFloat32>* output,
const Value<TensorRef<BHWC>>* output,
Node** passthru_node) {
*passthru_node = graph->NewNode();
// Make copies for every output in the original node.
RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
Value<TensorRefFloat32>* copy_output = graph->NewValue();
Value<TensorRef<BHWC>>* copy_output = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
copy_output->tensor = output->tensor;
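To see the updated signature in context, here is a short usage sketch; it assumes the GraphFloat32/Value API and the status macros visible elsewhere in this diff, and the operation type chosen for the passthrough is a hypothetical choice for illustration only.

// Sketch only: assumes tensorflow/lite/delegates/gpu/common/model.h,
// status.h, and operations.h from this codebase are in scope.
Status InsertPassthroughSketch(GraphFloat32* graph, Node* node,
                               Value<TensorRef<BHWC>>* output) {
  // Rewires "node -> output" into "node -> copy_output -> passthru -> output".
  Node* passthru_node = nullptr;
  RETURN_IF_ERROR(NewPassthroughNode(graph, node, output, &passthru_node));
  // A caller would then assign the passthrough a concrete operation
  // (hypothetical example):
  passthru_node->operation.type = ToString(OperationType::RELU);
  return OkStatus();
}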
@@ -265,13 +265,13 @@ class ObjectReader {
public:
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
const TfLiteNode* tflite_node,
std::vector<Value<TensorRefFloat32>*>* tensor_to_value)
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value)
: graph_(graph),
context_(context),
tflite_node_(tflite_node),
tensor_to_value_(tensor_to_value) {}
Status ReadValue(uint32_t idx, Value<TensorRefFloat32>** value) {
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) {
if (idx >= tflite_node_->inputs->size) {
return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx));
}
@@ -319,7 +319,7 @@ class ObjectReader {
tflite_node_->outputs->size));
}
int output_tensor_idx = tflite_node_->outputs->data[id];
Value<TensorRefFloat32>* value;
Value<TensorRef<BHWC>>* value;
RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
return OkStatus();
@@ -333,13 +333,13 @@ class ObjectReader {
}
Status AddInput(const Node* node, uint32_t idx) {
Value<TensorRefFloat32>* input;
Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(ReadValue(idx, &input));
return graph_->AddConsumer(node->id, input->id);
}
Status ReadValueByTensorIdx(uint32_t tensor_idx,
Value<TensorRefFloat32>** value) {
Value<TensorRef<BHWC>>** value) {
if (tensor_idx >= tensor_to_value_->size()) {
return OutOfRangeError(
StrCat("ReadValue: input tensor index: ", tensor_idx));
@@ -350,7 +350,7 @@ class ObjectReader {
return NotFoundError(
StrCat("ReadValue: value is a constant tensor: ", tensor_idx));
}
Value<TensorRefFloat32>* value = graph_->NewValue();
Value<TensorRef<BHWC>>* value = graph_->NewValue();
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
value->tensor.ref = tensor_idx;
@@ -364,7 +364,7 @@ class ObjectReader {
GraphFloat32* graph_ = nullptr;
const TfLiteContext* context_ = nullptr;
const TfLiteNode* tflite_node_ = nullptr;
std::vector<Value<TensorRefFloat32>*>* tensor_to_value_;
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value_;
};
Status CheckInputsOutputs(const TfLiteContext* context,
@@ -639,7 +639,7 @@ class Conv2DOperationParser : public TFLiteOperationParser {
// Creates a simple node that holds tensor value.
Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
Value<TensorRefFloat32>** value) {
Value<TensorRef<BHWC>>** value) {
ConstTensorAttributes attr;
attr.tensor = std::move(t);
Node* node = graph->NewNode();
@@ -677,16 +677,16 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
ConcatAttributes attr;
// Read inputs first to make sure const node is added to a graph before
// concat node to ensure topological order.
std::vector<const Value<TensorRefFloat32>*> inputs;
std::vector<const Value<TensorRef<BHWC>>*> inputs;
for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
Value<TensorRefFloat32>* value;
Value<TensorRef<BHWC>>* value;
const auto status = reader->ReadValue(idx, &value);
if (status.ok()) {
inputs.push_back(value);
} else {
TensorFloat32 tensor;
RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
Value<TensorRefFloat32>* value;
Value<TensorRef<BHWC>>* value;
RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
inputs.push_back(value);
}
@@ -695,7 +695,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONCAT);
RETURN_IF_ERROR(reader->AddOutputs(node));
for (const Value<TensorRefFloat32>* input : inputs) {
for (const Value<TensorRef<BHWC>>* input : inputs) {
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
}
@@ -1143,11 +1143,11 @@ class LstmOperationParser : public TFLiteOperationParser {
lstm_attr.kernel_type = LstmKernelType::BASIC;
lstm_node->operation.attributes = lstm_attr;
Value<TensorRefFloat32>* concat_temp;
Value<TensorRef<BHWC>>* concat_temp;
int concat_tensor_idx = tflite_node->outputs->data[2];
RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
Value<TensorRefFloat32>* activ_temp;
Value<TensorRef<BHWC>>* activ_temp;
int activ_tensor_idx = tflite_node->outputs->data[3];
RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
@@ -1521,7 +1521,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
auto& reshape = node;
conv = graph->NewNode(); // reset conv pointer!
Value<TensorRefFloat32>* reshaped_value = graph->NewValue();
Value<TensorRef<BHWC>>* reshaped_value = graph->NewValue();
reshaped_value->tensor.shape = BHWC(1, 1, 1, weights.shape.w);
RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
reshape->operation.type = ToString(OperationType::RESHAPE);
@@ -1558,7 +1558,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SLICE);
RETURN_IF_ERROR(reader->AddOutputs(node));
Value<TensorRefFloat32>* input;
Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(reader->ReadValue(0, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
@@ -1721,7 +1721,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
ObjectReader* reader) final {
auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
Value<TensorRefFloat32>* input;
Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(reader->ReadValue(2, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
RETURN_IF_ERROR(reader->AddOutputs(node));
@@ -1970,7 +1970,7 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
} // namespace
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
TensorRefFloat32* tensor_ref) {
TensorRef<BHWC>* tensor_ref) {
tensor_ref->type = ToDataType(tflite_tensor.type);
const TfLiteIntArray* dims = tflite_tensor.dims;
switch (dims->size) {
@@ -2128,8 +2128,8 @@ Status BuildModel(TfLiteContext* context,
}
operations.push_back(std::move(op_parser));
}
std::vector<Value<TensorRefFloat32>*> tensor_to_value(context->tensors_size,
nullptr);
std::vector<Value<TensorRef<BHWC>>*> tensor_to_value(context->tensors_size,
nullptr);
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
TfLiteNode* tflite_node = nullptr;
TfLiteRegistration* registration = nullptr;


@@ -38,7 +38,7 @@ Status BuildModel(TfLiteContext* context,
// Module-internal converter, exposed for unit testing purpose only.
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
TensorRefFloat32* tensor_ref);
TensorRef<BHWC>* tensor_ref);
} // namespace gpu
} // namespace tflite


@@ -34,7 +34,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank0) {
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(1);
tflite_tensor.dims->data[0] = 4;
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
@@ -49,7 +49,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank1) {
tflite_tensor.dims = TfLiteIntArrayCreate(2);
tflite_tensor.dims->data[0] = 4;
tflite_tensor.dims->data[1] = 5;
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
@@ -65,7 +65,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank2) {
tflite_tensor.dims->data[0] = 4;
tflite_tensor.dims->data[1] = 5;
tflite_tensor.dims->data[2] = 6;
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
@@ -82,7 +82,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank3) {
tflite_tensor.dims->data[1] = 5;
tflite_tensor.dims->data[2] = 6;
tflite_tensor.dims->data[3] = 7;
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
@@ -95,7 +95,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankLT0) {
TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(0);
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
@@ -107,7 +107,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) {
TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(5);
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);


@@ -31,8 +31,8 @@ TEST(Model, SingleNode) {
// graph_input -> node -> graph_output
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
@@ -52,9 +52,9 @@ TEST(Model, SingleNodeMultipleOutputs) {
// graph_input -> node -> (graph_output1, graph_output2)
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output1 = graph.NewValue();
Value<TensorRefFloat32>* graph_output2 = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output1 = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output1->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output2->id).ok());
@@ -67,7 +67,7 @@ TEST(Model, SingleNodeMultipleOutputs) {
TEST(Model, SetSameConsumer) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
EXPECT_FALSE(graph.AddConsumer(node->id, graph_input->id).ok());
}
@@ -76,8 +76,8 @@ TEST(Model, RemoveConsumer) {
// (graph_input1, graph_input2) -> node
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input1 = graph.NewValue();
Value<TensorRefFloat32>* graph_input2 = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input1 = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input2 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input1->id).ok());
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input2->id).ok());
EXPECT_THAT(graph.FindConsumers(graph_input1->id),
@@ -101,7 +101,7 @@ TEST(Model, RemoveConsumer) {
TEST(Model, SetSameProducer) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
EXPECT_FALSE(graph.SetProducer(node->id, graph_output->id).ok());
}
@@ -109,7 +109,7 @@ TEST(Model, SetSameProducer) {
TEST(Model, RemoveProducer) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
@@ -126,8 +126,8 @@ TEST(Model, RemoveProducer) {
TEST(Model, RemoveSimpleNodeDegenerateCase) {
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
@@ -145,9 +145,9 @@ TEST(Model, RemoveSimpleNodeNoPreviousNode) {
GraphFloat32 graph;
Node* simple_node = graph.NewNode();
Node* consumer_node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* value = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(simple_node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(simple_node->id, value->id).ok());
@@ -167,9 +167,9 @@ TEST(Model, RemoveSimpleNodeNoAfterNodes) {
GraphFloat32 graph;
Node* simple_node = graph.NewNode();
Node* producer_node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* value = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(simple_node->id, value->id).ok());
ASSERT_TRUE(graph.SetProducer(simple_node->id, graph_output->id).ok());
@@ -190,10 +190,10 @@ TEST(Model, RemoveSimpleNodeGeneralCase) {
Node* simple_node = graph.NewNode();
Node* producer_node = graph.NewNode();
Node* consumer_node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* value0 = graph.NewValue();
Value<TensorRefFloat32>* value1 = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value0 = graph.NewValue();
Value<TensorRef<BHWC>>* value1 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(producer_node->id, value0->id).ok());
@@ -217,14 +217,14 @@ TEST(Model, CircularDependency) {
{
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* value = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, value->id).ok());
EXPECT_FALSE(graph.SetProducer(node->id, value->id).ok());
}
{
GraphFloat32 graph;
Node* node = graph.NewNode();
Value<TensorRefFloat32>* value = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, value->id).ok());
EXPECT_FALSE(graph.AddConsumer(node->id, value->id).ok());
}
@@ -237,8 +237,8 @@ TEST(Model, ReassignValue) {
GraphFloat32 graph;
Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, graph_output->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, graph_input->id).ok());
@@ -264,9 +264,9 @@ TEST(Model, DeleteValue) {
GraphFloat32 graph;
Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* value = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
@@ -305,10 +305,10 @@ TEST(Model, DeleteNode) {
Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode();
Node* node3 = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* graph_output2 = graph.NewValue();
Value<TensorRefFloat32>* value = graph.NewValue();
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());


@@ -57,11 +57,11 @@ TEST(MergeConvolutionWithAddTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link1;
Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);
@@ -108,11 +108,11 @@ TEST(MergeAddWithConvolutionTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link1;
Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);


@@ -58,11 +58,11 @@ TEST(MergeConvolutionWithMulTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link1;
Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);
@@ -109,11 +109,11 @@ TEST(MergeMulWithConvolutionTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok());
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link1;
Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);


@@ -68,16 +68,16 @@ TEST(MakeFullyConnected, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok());
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok());
output->tensor.shape = BHWC(1, 1, 1, 32);
Value<TensorRefFloat32>* link1;
Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(
ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link2;
Value<TensorRef<BHWC>>* link2;
ASSERT_TRUE(
ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok());
link2->tensor.shape = BHWC(1, 1, 1, 16);


@@ -38,7 +38,7 @@ TEST(MakePadding, Smoke) {
attr.axis = Axis::HEIGHT;
concat_node->operation.attributes = attr;
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok());
output->tensor.shape = BHWC(1, 7, 3, 5);
@@ -50,7 +50,7 @@ TEST(MakePadding, Smoke) {
std::vector<float>(const_attr.tensor.shape.DimensionsProduct(), 0);
const_node->operation.attributes = const_attr;
Value<TensorRefFloat32>* const_link;
Value<TensorRef<BHWC>>* const_link;
ASSERT_TRUE(
ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok());
const_link->tensor.shape = const_attr.tensor.shape;


@@ -62,15 +62,15 @@ TEST(MatchDilatedConvolutionTest, MakesDilatedConvolution) {
ASSERT_TRUE(graph.AddConsumer(sb_node->id, input->id).ok());
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, bs_node, &output).ok());
output->tensor.shape = BHWC(1, 95, 1, 17);
Value<TensorRefFloat32>* sb_link;
Value<TensorRef<BHWC>>* sb_link;
ASSERT_TRUE(ConnectTwoNodes(&graph, sb_node, dw_node, &sb_link).ok());
sb_link->tensor.shape = BHWC(21, 128, 1, 17);
Value<TensorRefFloat32>* bs_link;
Value<TensorRef<BHWC>>* bs_link;
ASSERT_TRUE(ConnectTwoNodes(&graph, dw_node, bs_node, &bs_link).ok());
bs_link->tensor.shape = BHWC(1, 95, 1, 17);


@@ -40,7 +40,7 @@ TEST(MergePaddingWith, Smoke) {
pad_node->operation.attributes = attr;
auto conv_node = graph.NewNode();
Value<TensorRefFloat32>* temp;
Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok());
ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok());
conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
@@ -77,7 +77,7 @@ TEST(MergePaddingWith, MergeTwo) {
pad_node1->operation.attributes = attr;
auto pad_node2 = graph.NewNode();
Value<TensorRefFloat32>* temp;
Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok());
pad_node2->operation.type = ToString(OperationType::PAD);
attr.prepended = HWC(0, 0, 0);


@@ -33,12 +33,12 @@ TEST(RemoveSingleInputAdd, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto add_node = graph.NewNode();
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
add_node->operation.attributes = AddAttributes();
Value<TensorRefFloat32>* temp;
Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size());
@@ -61,14 +61,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Tensor) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto add_node = graph.NewNode();
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
AddAttributes attr;
attr.param = Tensor<Linear, DataType::FLOAT32>();
add_node->operation.attributes = attr;
Value<TensorRefFloat32>* temp;
Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size());
@@ -90,11 +90,11 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Multiple) {
ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok());
auto add_node = graph.NewNode();
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD);
Value<TensorRefFloat32>* temp;
Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok());
ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok());
ASSERT_EQ(3, graph.nodes().size());
@@ -115,7 +115,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto node_to_remove = graph.NewNode();
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
output->tensor.shape = BHWC(1, 5, 5, 1);
node_to_remove->operation.type = ToString(OperationType::UPSAMPLE_2D);
@@ -124,7 +124,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
attr.type = UpsamplingType::BILINEAR;
node_to_remove->operation.attributes = attr;
Value<TensorRefFloat32>* link;
Value<TensorRef<BHWC>>* link;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
link->tensor.shape = output->tensor.shape;
ASSERT_EQ(2, graph.nodes().size());
@@ -148,7 +148,7 @@ TEST(RemoveIdentityReshape, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto node_to_remove = graph.NewNode();
Value<TensorRefFloat32>* output;
Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
output->tensor.shape = BHWC(1, 1, 1, 11);
node_to_remove->operation.type = ToString(OperationType::RESHAPE);
@@ -156,7 +156,7 @@ TEST(RemoveIdentityReshape, Smoke) {
attr.new_shape = BHWC(1, 1, 1, 11);
node_to_remove->operation.attributes = attr;
Value<TensorRefFloat32>* link;
Value<TensorRef<BHWC>>* link;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
link->tensor.shape = output->tensor.shape;
ASSERT_EQ(2, graph.nodes().size());


@@ -31,7 +31,7 @@ namespace gl {
namespace {
TEST(AddTest, TwoInputTensorsOfTheSameShape) {
TensorRefFloat32 augend, addend, output;
TensorRef<BHWC> augend, addend, output;
augend.type = DataType::FLOAT32;
augend.ref = 0;
augend.shape = BHWC(1, 2, 2, 1);
@@ -57,7 +57,7 @@ TEST(AddTest, TwoInputTensorsOfTheSameShape) {
TEST(AddTest, InputTensorAndScalar) {
AddAttributes attr;
attr.param = 0.1f;
TensorRefFloat32 input, output;
TensorRef<BHWC> input, output;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 1, 2);
@@ -75,7 +75,7 @@ TEST(AddTest, InputTensorAndScalar) {
}
TEST(AddTest, InputTensorWithConstandBroadcast) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 2);
@@ -88,7 +88,7 @@ TEST(AddTest, InputTensorWithConstandBroadcast) {
tensor.data.push_back(20.0);
attr.param = std::move(tensor);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 2, 2, 2);
@@ -104,19 +104,19 @@ TEST(AddTest, InputTensorWithConstandBroadcast) {
}
TEST(AddTest, InputTensorWithRuntimeBroadcast) {
TensorRefFloat32 input1;
TensorRef<BHWC> input1;
input1.type = DataType::FLOAT32;
input1.ref = 0;
input1.shape = BHWC(1, 2, 2, 2);
TensorRefFloat32 input2;
TensorRef<BHWC> input2;
input2.type = DataType::FLOAT32;
input2.ref = 1;
input2.shape = BHWC(1, 1, 1, 2);
AddAttributes attr;
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 2, 2, 2);


@@ -31,7 +31,7 @@ namespace gl {
namespace {
TEST(ConcatTest, TwoInputTensorsByUnalignedChannel) {
TensorRefFloat32 input1, input2, output;
TensorRef<BHWC> input1, input2, output;
input1.type = DataType::FLOAT32;
input1.ref = 0;
input1.shape = BHWC(1, 2, 2, 1);
@@ -57,7 +57,7 @@ TEST(ConcatTest, TwoInputTensorsByUnalignedChannel) {
}
TEST(ConcatTest, TwoInputTensorsByAlignedChannel) {
TensorRefFloat32 input1, input2, output;
TensorRef<BHWC> input1, input2, output;
input1.type = DataType::FLOAT32;
input1.ref = 0;
input1.shape = BHWC(1, 1, 1, 4);
@@ -83,7 +83,7 @@ TEST(ConcatTest, TwoInputTensorsByAlignedChannel) {
}
TEST(ConcatTest, TwoInputTensorsByHeight) {
TensorRefFloat32 input1, input2, output;
TensorRef<BHWC> input1, input2, output;
input1.type = DataType::FLOAT32;
input1.ref = 0;
input1.shape = BHWC(1, 1, 2, 1);
@@ -109,7 +109,7 @@ TEST(ConcatTest, TwoInputTensorsByHeight) {
}
TEST(ConcatTest, TwoInputTensorsByWidth) {
TensorRefFloat32 input1, input2, output;
TensorRef<BHWC> input1, input2, output;
input1.type = DataType::FLOAT32;
input1.ref = 0;
input1.shape = BHWC(1, 2, 1, 1);


@@ -31,7 +31,7 @@ namespace gl {
namespace {
TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
@@ -54,7 +54,7 @@ TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
attr.padding.appended = HW(1, 0);
attr.strides = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 2, 2, 2);
@@ -69,7 +69,7 @@ TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
}
TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 3, 1);
@@ -92,7 +92,7 @@ TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 1, 1, 1);
@@ -106,7 +106,7 @@ TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
}
TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
@@ -129,7 +129,7 @@ TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 1, 1, 1);
@@ -143,7 +143,7 @@ TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
}
TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 1, 2);
@@ -166,7 +166,7 @@ TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 2, 1, 2);
@@ -180,7 +180,7 @@ TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
}
TEST(ConvTest, O1H1W1I1Stride2x2Dilation1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 3, 1);
@@ -204,7 +204,7 @@ TEST(ConvTest, O1H1W1I1Stride2x2Dilation1x1) {
attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 2, 2, 1);


@@ -31,7 +31,7 @@ namespace gl {
namespace {
TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 1, 2);
@@ -55,7 +55,7 @@ TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 1, 1, 4);
@@ -69,7 +69,7 @@ TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
}
TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 3, 1);
@@ -93,7 +93,7 @@ TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 2, 2, 2);
@@ -108,7 +108,7 @@ TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
}
TEST(DepthwiseConvTest, O2H2W2I1Strides1x1Dilation2x2) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 3, 1);
@@ -132,7 +132,7 @@ TEST(DepthwiseConvTest, O2H2W2I1Strides1x1Dilation2x2) {
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 1, 1, 2);


@@ -33,8 +33,8 @@ class ElementwiseOneArgumentTest : public ::testing::Test {
ElementwiseOneArgumentTest() = default;
~ElementwiseOneArgumentTest() override = default;
TensorRefFloat32 GetTensorRef(int ref) {
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> GetTensorRef(int ref) {
TensorRef<BHWC> tensor_ref;
tensor_ref.type = DataType::FLOAT32;
tensor_ref.ref = ref;
tensor_ref.shape = BHWC(1, 2, 2, 1);
@@ -137,8 +137,8 @@ class ElementwiseTwoArgumentsTest : public ::testing::Test {
ElementwiseTwoArgumentsTest() = default;
~ElementwiseTwoArgumentsTest() override = default;
TensorRefFloat32 GetTensorRef(int ref) {
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> GetTensorRef(int ref) {
TensorRef<BHWC> tensor_ref;
tensor_ref.type = DataType::FLOAT32;
tensor_ref.ref = ref;
tensor_ref.shape = BHWC(1, 2, 2, 1);


@@ -31,7 +31,7 @@ namespace gl {
namespace {
TEST(FullyConnectedTest, MatrixByVectorMultiplication) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 1, 2);
@@ -50,7 +50,7 @@ TEST(FullyConnectedTest, MatrixByVectorMultiplication) {
weights.data = {1, 2, 3, 4, 5, 6, 7, 8};
attr.weights = std::move(weights);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 1, 1, 4);


@@ -31,22 +31,22 @@ namespace gl {
namespace {
TEST(LstmTest, Input2x2x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 prev_state;
TensorRef<BHWC> prev_state;
prev_state.type = DataType::FLOAT32;
prev_state.ref = 1;
prev_state.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output_state;
TensorRef<BHWC> output_state;
output_state.type = DataType::FLOAT32;
output_state.ref = 2;
output_state.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output_activation;
TensorRef<BHWC> output_activation;
output_activation.type = DataType::FLOAT32;
output_activation.ref = 3;
output_activation.shape = BHWC(1, 2, 2, 1);


@@ -31,17 +31,17 @@ namespace gl {
namespace {
TEST(MaxUnpoolingTest, Kernel2x2Stride2x2) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 indices;
TensorRef<BHWC> indices;
indices.type = DataType::INT32;
indices.ref = 1;
indices.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 4, 4, 1);


@@ -31,12 +31,12 @@ namespace gl {
namespace {
TEST(MulTest, Scalar) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
@@ -51,12 +51,12 @@ TEST(MulTest, Scalar) {
}
TEST(MulTest, Linear) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 2, 2);
@@ -75,17 +75,17 @@ TEST(MulTest, Linear) {
}
TEST(ApplyMaskTest, MaskChannel1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 mask;
TensorRef<BHWC> mask;
mask.type = DataType::FLOAT32;
mask.ref = 1;
mask.shape = BHWC(1, 1, 2, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);
@@ -99,17 +99,17 @@ TEST(ApplyMaskTest, MaskChannel1) {
}
TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 mask;
TensorRef<BHWC> mask;
mask.type = DataType::FLOAT32;
mask.ref = 1;
mask.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);


@@ -34,12 +34,12 @@ namespace {
void TestPadOperation(const HWC& prepend, const HWC& append,
const BHWC& output_shape, std::vector<float>&& expected) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = output_shape;


@@ -35,12 +35,12 @@ namespace gl {
namespace {
TEST(PoolingTest, MaxKernel2x2Stride2x2WithIndices) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 4, 4, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
@@ -70,12 +70,12 @@ TEST(PoolingTest, MaxKernel2x2Stride2x2WithIndices) {
}
TEST(PoolingTest, MaxKernel2x2Stride2x2WithoutIndices) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 4, 4, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
@@ -96,12 +96,12 @@ TEST(PoolingTest, MaxKernel2x2Stride2x2WithoutIndices) {
}
TEST(PoolingTest, AverageKernel2x2Stride2x2) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 4, 4, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);


@@ -29,7 +29,7 @@ namespace gl {
namespace {
TEST(PReluTest, LinearAlphaNoClip) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
@@ -42,7 +42,7 @@ TEST(PReluTest, LinearAlphaNoClip) {
alpha.data = {2};
attr.alpha = std::move(alpha);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 2, 2, 1);
@@ -55,7 +55,7 @@ TEST(PReluTest, LinearAlphaNoClip) {
}
TEST(PReluTest, LinearAlphaWithClip) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
@@ -68,7 +68,7 @@ TEST(PReluTest, LinearAlphaWithClip) {
alpha.data = {2};
attr.alpha = std::move(alpha);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 2, 2, 1);
@@ -81,7 +81,7 @@ TEST(PReluTest, LinearAlphaWithClip) {
}
TEST(PReluTest, 3DAlphaNoClip) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
@@ -95,7 +95,7 @@ TEST(PReluTest, 3DAlphaNoClip) {
alpha.data = {1, 2, 2, 2};
attr.alpha = std::move(alpha);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 2, 2, 1);
@@ -107,7 +107,7 @@ TEST(PReluTest, 3DAlphaNoClip) {
}
TEST(PReluTest, 3DAlphaWithClip) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
@@ -121,7 +121,7 @@ TEST(PReluTest, 3DAlphaWithClip) {
alpha.data = {1, 2, 2, 2};
attr.alpha = std::move(alpha);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 2, 2, 1);


@@ -33,8 +33,8 @@ class ReluTest : public ::testing::Test {
ReluTest() = default;
~ReluTest() override = default;
TensorRefFloat32 GetTensorRef(int ref) {
TensorRefFloat32 tensor_ref;
TensorRef<BHWC> GetTensorRef(int ref) {
TensorRef<BHWC> tensor_ref;
tensor_ref.type = DataType::FLOAT32;
tensor_ref.ref = ref;
tensor_ref.shape = BHWC(1, 2, 2, 1);


@@ -31,12 +31,12 @@ namespace gl {
namespace {
TEST(Reshape, 1x2x3To3x2x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 3);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 3, 2, 1);
@@ -53,12 +53,12 @@ TEST(Reshape, 1x2x3To3x2x1) {
}
TEST(Reshape, 3x1x2To2x1x3) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 1, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 1, 3);
@@ -75,12 +75,12 @@ TEST(Reshape, 3x1x2To2x1x3) {
}
TEST(Reshape, 1x1x4To2x2x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 1, 4);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
@@ -96,12 +96,12 @@ TEST(Reshape, 1x1x4To2x2x1) {
}
TEST(Reshape, BatchIsUnsupported) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(4, 1, 1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);


@@ -31,12 +31,12 @@ namespace gl {
namespace {
TEST(SliceTest, Identity) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 2, 2);
@@ -54,12 +54,12 @@ TEST(SliceTest, Identity) {
}
TEST(SliceTest, NegativeEnds) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 2, 2);
@@ -77,12 +77,12 @@ TEST(SliceTest, NegativeEnds) {
}
TEST(SliceTest, NegativeEndsNonZeroStarts) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 1, 1);
@@ -100,12 +100,12 @@ TEST(SliceTest, NegativeEndsNonZeroStarts) {
}
TEST(SliceTest, StridesByHeight) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 4, 1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 1, 1);
@@ -123,12 +123,12 @@ TEST(SliceTest, StridesByHeight) {
}
TEST(SliceTest, StridesByWidth) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 4, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 2, 1);
@@ -146,12 +146,12 @@ TEST(SliceTest, StridesByWidth) {
}
TEST(SliceTest, StridesByChannels) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 1, 4);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 1, 1);


@@ -32,12 +32,12 @@ namespace gl {
namespace {
TEST(SoftmaxTest, WorksForChannelsAxis) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
@@ -53,12 +53,12 @@ TEST(SoftmaxTest, WorksForChannelsAxis) {
}
TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
@@ -75,12 +75,12 @@ TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
}
TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);


@@ -37,8 +37,8 @@ namespace gpu {
namespace gl {
SingleOpModel::SingleOpModel(Operation&& operation,
const std::vector<TensorRefFloat32>& inputs,
const std::vector<TensorRefFloat32>& outputs) {
const std::vector<TensorRef<BHWC>>& inputs,
const std::vector<TensorRef<BHWC>>& outputs) {
auto node = graph_.NewNode();
node->operation = std::move(operation);


@@ -41,8 +41,8 @@ class SingleOpModel {
public:
SingleOpModel() = delete;
SingleOpModel(Operation&& operation,
const std::vector<TensorRefFloat32>& inputs,
const std::vector<TensorRefFloat32>& outputs);
const std::vector<TensorRef<BHWC>>& inputs,
const std::vector<TensorRef<BHWC>>& outputs);
virtual ~SingleOpModel() = default;


@@ -31,7 +31,7 @@ namespace gl {
namespace {
TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
@@ -54,7 +54,7 @@ TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
attr.adjacent = HW(1, 1);
attr.stride = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 2, 2, 2);
@@ -69,7 +69,7 @@ TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
}
TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 3, 1);
@@ -92,7 +92,7 @@ TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
attr.padding.appended = HW(0, 0);
attr.stride = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 1, 1, 1);
@@ -106,7 +106,7 @@ TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
}
TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
@@ -129,7 +129,7 @@ TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
attr.padding.appended = HW(0, 0);
attr.stride = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 1, 1, 1);
@@ -143,7 +143,7 @@ TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
}
TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 1, 2);
@@ -166,7 +166,7 @@ TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
attr.padding.appended = HW(0, 0);
attr.stride = HW(1, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 2, 1, 2);
@@ -180,7 +180,7 @@ TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
}
TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 3, 1);
@@ -204,7 +204,7 @@ TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) {
attr.padding.appended = HW(0, 0);
attr.stride = HW(2, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 3;
output.shape = BHWC(1, 1, 1, 1);


@@ -31,12 +31,12 @@ namespace gl {
namespace {
TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 1, 2);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 2);
@@ -56,12 +56,12 @@ TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) {
}
TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 4, 1);
@@ -80,12 +80,12 @@ TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) {
}
TEST(UpsamplingBilinearTest, 2x2x1To4x4x1) {
TensorRefFloat32 input;
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output;
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 4, 4, 1);


@@ -32,7 +32,7 @@ Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor,
return CreateReadOnlyShaderStorageBuffer<float>(transposed, gl_buffer);
}
Status CreatePHWC4BufferFromTensorRef(const TensorRefFloat32& tensor_ref,
Status CreatePHWC4BufferFromTensorRef(const TensorRef<BHWC>& tensor_ref,
GlBuffer* gl_buffer) {
return CreateReadWriteShaderStorageBuffer<float>(
GetElementsSizeForPHWC4(tensor_ref.shape), gl_buffer);
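The buffer size above comes from GetElementsSizeForPHWC4. As a hedged sketch of the layout rule (my reconstruction for illustration, not the helper from this commit): PHWC4 stores HWC data with the channel dimension rounded up to a multiple of 4, so the element count is b * h * w * ceil(c / 4) * 4.

// Reconstruction only; the real helper lives in the GL converter sources and
// may differ in details.
#include <cstdint>

struct BHWC { int32_t b = 0, h = 0, w = 0, c = 0; };

uint32_t ElementsSizeForPHWC4Sketch(const BHWC& shape) {
  const int32_t padded_channels = (shape.c + 3) / 4 * 4;  // round channels up to 4
  return static_cast<uint32_t>(shape.b * shape.h * shape.w * padded_channels);
}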


@@ -72,7 +72,7 @@ Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor,
// Creates read-write buffer for the given tensor shape, where data layout is
// supposed to be PHWC4.
Status CreatePHWC4BufferFromTensorRef(const TensorRefFloat32& tensor_ref,
Status CreatePHWC4BufferFromTensorRef(const TensorRef<BHWC>& tensor_ref,
GlBuffer* gl_buffer);
// Copies data from a buffer that holds data in PHWC4 layout to the given


@@ -148,7 +148,7 @@ class Delegate {
// TODO(impjdi): Remove code duplication.
auto values = graph.values();
auto find_value = [&](int tensor_index) -> Value<TensorRefFloat32>* {
auto find_value = [&](int tensor_index) -> Value<TensorRef<BHWC>>* {
for (auto value : values) {
if (value->tensor.ref == tensor_index) return value;
}