Get rid of TensorRefFloat32 specialization. It no longer keeps type in template params, therefore, it is better to use ref explicitly.

PiperOrigin-RevId: 253738844
This commit is contained in:
A. Unique TensorFlower 2019-06-18 00:38:53 -07:00 committed by TensorFlower Gardener
parent 3557034977
commit 5044a18831
34 changed files with 207 additions and 207 deletions

View File

@ -68,12 +68,12 @@ int64_t DimensionsProduct(const TfLiteIntArray& dims) {
// will turn into: // will turn into:
// node(copy(output)) <- passthrough_node(output) // node(copy(output)) <- passthrough_node(output)
Status NewPassthroughNode(GraphFloat32* graph, Node* node, Status NewPassthroughNode(GraphFloat32* graph, Node* node,
const Value<TensorRefFloat32>* output, const Value<TensorRef<BHWC>>* output,
Node** passthru_node) { Node** passthru_node) {
*passthru_node = graph->NewNode(); *passthru_node = graph->NewNode();
// Make copies for every output in the original node. // Make copies for every output in the original node.
RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id)); RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
Value<TensorRefFloat32>* copy_output = graph->NewValue(); Value<TensorRef<BHWC>>* copy_output = graph->NewValue();
RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id)); RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id)); RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
copy_output->tensor = output->tensor; copy_output->tensor = output->tensor;
@ -265,13 +265,13 @@ class ObjectReader {
public: public:
ObjectReader(GraphFloat32* graph, TfLiteContext* context, ObjectReader(GraphFloat32* graph, TfLiteContext* context,
const TfLiteNode* tflite_node, const TfLiteNode* tflite_node,
std::vector<Value<TensorRefFloat32>*>* tensor_to_value) std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value)
: graph_(graph), : graph_(graph),
context_(context), context_(context),
tflite_node_(tflite_node), tflite_node_(tflite_node),
tensor_to_value_(tensor_to_value) {} tensor_to_value_(tensor_to_value) {}
Status ReadValue(uint32_t idx, Value<TensorRefFloat32>** value) { Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) {
if (idx >= tflite_node_->inputs->size) { if (idx >= tflite_node_->inputs->size) {
return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx)); return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx));
} }
@ -319,7 +319,7 @@ class ObjectReader {
tflite_node_->outputs->size)); tflite_node_->outputs->size));
} }
int output_tensor_idx = tflite_node_->outputs->data[id]; int output_tensor_idx = tflite_node_->outputs->data[id];
Value<TensorRefFloat32>* value; Value<TensorRef<BHWC>>* value;
RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value)); RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id)); RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
return OkStatus(); return OkStatus();
@ -333,13 +333,13 @@ class ObjectReader {
} }
Status AddInput(const Node* node, uint32_t idx) { Status AddInput(const Node* node, uint32_t idx) {
Value<TensorRefFloat32>* input; Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(ReadValue(idx, &input)); RETURN_IF_ERROR(ReadValue(idx, &input));
return graph_->AddConsumer(node->id, input->id); return graph_->AddConsumer(node->id, input->id);
} }
Status ReadValueByTensorIdx(uint32_t tensor_idx, Status ReadValueByTensorIdx(uint32_t tensor_idx,
Value<TensorRefFloat32>** value) { Value<TensorRef<BHWC>>** value) {
if (tensor_idx >= tensor_to_value_->size()) { if (tensor_idx >= tensor_to_value_->size()) {
return OutOfRangeError( return OutOfRangeError(
StrCat("ReadValue: input tensor index: ", tensor_idx)); StrCat("ReadValue: input tensor index: ", tensor_idx));
@ -350,7 +350,7 @@ class ObjectReader {
return NotFoundError( return NotFoundError(
StrCat("ReadValue: value is a constant tensor: ", tensor_idx)); StrCat("ReadValue: value is a constant tensor: ", tensor_idx));
} }
Value<TensorRefFloat32>* value = graph_->NewValue(); Value<TensorRef<BHWC>>* value = graph_->NewValue();
RETURN_IF_ERROR( RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor)); ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
value->tensor.ref = tensor_idx; value->tensor.ref = tensor_idx;
@ -364,7 +364,7 @@ class ObjectReader {
GraphFloat32* graph_ = nullptr; GraphFloat32* graph_ = nullptr;
const TfLiteContext* context_ = nullptr; const TfLiteContext* context_ = nullptr;
const TfLiteNode* tflite_node_ = nullptr; const TfLiteNode* tflite_node_ = nullptr;
std::vector<Value<TensorRefFloat32>*>* tensor_to_value_; std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value_;
}; };
Status CheckInputsOutputs(const TfLiteContext* context, Status CheckInputsOutputs(const TfLiteContext* context,
@ -639,7 +639,7 @@ class Conv2DOperationParser : public TFLiteOperationParser {
// Creates a simple node that holds tensor value. // Creates a simple node that holds tensor value.
Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
Value<TensorRefFloat32>** value) { Value<TensorRef<BHWC>>** value) {
ConstTensorAttributes attr; ConstTensorAttributes attr;
attr.tensor = std::move(t); attr.tensor = std::move(t);
Node* node = graph->NewNode(); Node* node = graph->NewNode();
@ -677,16 +677,16 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
ConcatAttributes attr; ConcatAttributes attr;
// Read inputs first to make sure const node is added to a graph before // Read inputs first to make sure const node is added to a graph before
// concat node to ensure topological order. // concat node to ensure topological order.
std::vector<const Value<TensorRefFloat32>*> inputs; std::vector<const Value<TensorRef<BHWC>>*> inputs;
for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) { for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
Value<TensorRefFloat32>* value; Value<TensorRef<BHWC>>* value;
const auto status = reader->ReadValue(idx, &value); const auto status = reader->ReadValue(idx, &value);
if (status.ok()) { if (status.ok()) {
inputs.push_back(value); inputs.push_back(value);
} else { } else {
TensorFloat32 tensor; TensorFloat32 tensor;
RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor)); RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
Value<TensorRefFloat32>* value; Value<TensorRef<BHWC>>* value;
RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value)); RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
inputs.push_back(value); inputs.push_back(value);
} }
@ -695,7 +695,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode(); Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONCAT); node->operation.type = ToString(OperationType::CONCAT);
RETURN_IF_ERROR(reader->AddOutputs(node)); RETURN_IF_ERROR(reader->AddOutputs(node));
for (const Value<TensorRefFloat32>* input : inputs) { for (const Value<TensorRef<BHWC>>* input : inputs) {
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
} }
@ -1143,11 +1143,11 @@ class LstmOperationParser : public TFLiteOperationParser {
lstm_attr.kernel_type = LstmKernelType::BASIC; lstm_attr.kernel_type = LstmKernelType::BASIC;
lstm_node->operation.attributes = lstm_attr; lstm_node->operation.attributes = lstm_attr;
Value<TensorRefFloat32>* concat_temp; Value<TensorRef<BHWC>>* concat_temp;
int concat_tensor_idx = tflite_node->outputs->data[2]; int concat_tensor_idx = tflite_node->outputs->data[2];
RETURN_IF_ERROR( RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp)); reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
Value<TensorRefFloat32>* activ_temp; Value<TensorRef<BHWC>>* activ_temp;
int activ_tensor_idx = tflite_node->outputs->data[3]; int activ_tensor_idx = tflite_node->outputs->data[3];
RETURN_IF_ERROR( RETURN_IF_ERROR(
reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp)); reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
@ -1521,7 +1521,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) { if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
auto& reshape = node; auto& reshape = node;
conv = graph->NewNode(); // reset conv pointer! conv = graph->NewNode(); // reset conv pointer!
Value<TensorRefFloat32>* reshaped_value = graph->NewValue(); Value<TensorRef<BHWC>>* reshaped_value = graph->NewValue();
reshaped_value->tensor.shape = BHWC(1, 1, 1, weights.shape.w); reshaped_value->tensor.shape = BHWC(1, 1, 1, weights.shape.w);
RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id)); RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
reshape->operation.type = ToString(OperationType::RESHAPE); reshape->operation.type = ToString(OperationType::RESHAPE);
@ -1558,7 +1558,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
Node* node = graph->NewNode(); Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::SLICE); node->operation.type = ToString(OperationType::SLICE);
RETURN_IF_ERROR(reader->AddOutputs(node)); RETURN_IF_ERROR(reader->AddOutputs(node));
Value<TensorRefFloat32>* input; Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(reader->ReadValue(0, &input)); RETURN_IF_ERROR(reader->ReadValue(0, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
@ -1721,7 +1721,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
ObjectReader* reader) final { ObjectReader* reader) final {
auto* node = graph->NewNode(); auto* node = graph->NewNode();
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
Value<TensorRefFloat32>* input; Value<TensorRef<BHWC>>* input;
RETURN_IF_ERROR(reader->ReadValue(2, &input)); RETURN_IF_ERROR(reader->ReadValue(2, &input));
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
RETURN_IF_ERROR(reader->AddOutputs(node)); RETURN_IF_ERROR(reader->AddOutputs(node));
@ -1970,7 +1970,7 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
} // namespace } // namespace
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
TensorRefFloat32* tensor_ref) { TensorRef<BHWC>* tensor_ref) {
tensor_ref->type = ToDataType(tflite_tensor.type); tensor_ref->type = ToDataType(tflite_tensor.type);
const TfLiteIntArray* dims = tflite_tensor.dims; const TfLiteIntArray* dims = tflite_tensor.dims;
switch (dims->size) { switch (dims->size) {
@ -2128,7 +2128,7 @@ Status BuildModel(TfLiteContext* context,
} }
operations.push_back(std::move(op_parser)); operations.push_back(std::move(op_parser));
} }
std::vector<Value<TensorRefFloat32>*> tensor_to_value(context->tensors_size, std::vector<Value<TensorRef<BHWC>>*> tensor_to_value(context->tensors_size,
nullptr); nullptr);
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) { for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
TfLiteNode* tflite_node = nullptr; TfLiteNode* tflite_node = nullptr;

View File

@ -38,7 +38,7 @@ Status BuildModel(TfLiteContext* context,
// Module-internal converter, exposed for unit testing purpose only. // Module-internal converter, exposed for unit testing purpose only.
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
TensorRefFloat32* tensor_ref); TensorRef<BHWC>* tensor_ref);
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite

View File

@ -34,7 +34,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank0) {
tflite_tensor.type = TfLiteType::kTfLiteFloat32; tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(1); tflite_tensor.dims = TfLiteIntArrayCreate(1);
tflite_tensor.dims->data[0] = 4; tflite_tensor.dims->data[0] = 4;
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
const auto status = const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims); TfLiteIntArrayFree(tflite_tensor.dims);
@ -49,7 +49,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank1) {
tflite_tensor.dims = TfLiteIntArrayCreate(2); tflite_tensor.dims = TfLiteIntArrayCreate(2);
tflite_tensor.dims->data[0] = 4; tflite_tensor.dims->data[0] = 4;
tflite_tensor.dims->data[1] = 5; tflite_tensor.dims->data[1] = 5;
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
const auto status = const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims); TfLiteIntArrayFree(tflite_tensor.dims);
@ -65,7 +65,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank2) {
tflite_tensor.dims->data[0] = 4; tflite_tensor.dims->data[0] = 4;
tflite_tensor.dims->data[1] = 5; tflite_tensor.dims->data[1] = 5;
tflite_tensor.dims->data[2] = 6; tflite_tensor.dims->data[2] = 6;
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
const auto status = const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims); TfLiteIntArrayFree(tflite_tensor.dims);
@ -82,7 +82,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank3) {
tflite_tensor.dims->data[1] = 5; tflite_tensor.dims->data[1] = 5;
tflite_tensor.dims->data[2] = 6; tflite_tensor.dims->data[2] = 6;
tflite_tensor.dims->data[3] = 7; tflite_tensor.dims->data[3] = 7;
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
const auto status = const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims); TfLiteIntArrayFree(tflite_tensor.dims);
@ -95,7 +95,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankLT0) {
TfLiteTensor tflite_tensor; TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteFloat32; tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(0); tflite_tensor.dims = TfLiteIntArrayCreate(0);
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
const auto status = const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims); TfLiteIntArrayFree(tflite_tensor.dims);
@ -107,7 +107,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) {
TfLiteTensor tflite_tensor; TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteFloat32; tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(5); tflite_tensor.dims = TfLiteIntArrayCreate(5);
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
const auto status = const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims); TfLiteIntArrayFree(tflite_tensor.dims);

View File

@ -31,8 +31,8 @@ TEST(Model, SingleNode) {
// graph_input -> node -> graph_output // graph_input -> node -> graph_output
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
@ -52,9 +52,9 @@ TEST(Model, SingleNodeMultipleOutputs) {
// graph_input -> node -> (graph_output1, graph_output2) // graph_input -> node -> (graph_output1, graph_output2)
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output1 = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output1 = graph.NewValue();
Value<TensorRefFloat32>* graph_output2 = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output1->id).ok()); ASSERT_TRUE(graph.SetProducer(node->id, graph_output1->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output2->id).ok()); ASSERT_TRUE(graph.SetProducer(node->id, graph_output2->id).ok());
@ -67,7 +67,7 @@ TEST(Model, SingleNodeMultipleOutputs) {
TEST(Model, SetSameConsumer) { TEST(Model, SetSameConsumer) {
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
EXPECT_FALSE(graph.AddConsumer(node->id, graph_input->id).ok()); EXPECT_FALSE(graph.AddConsumer(node->id, graph_input->id).ok());
} }
@ -76,8 +76,8 @@ TEST(Model, RemoveConsumer) {
// (graph_input1, graph_input2) -> node // (graph_input1, graph_input2) -> node
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input1 = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input1 = graph.NewValue();
Value<TensorRefFloat32>* graph_input2 = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input2 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input1->id).ok()); ASSERT_TRUE(graph.AddConsumer(node->id, graph_input1->id).ok());
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input2->id).ok()); ASSERT_TRUE(graph.AddConsumer(node->id, graph_input2->id).ok());
EXPECT_THAT(graph.FindConsumers(graph_input1->id), EXPECT_THAT(graph.FindConsumers(graph_input1->id),
@ -101,7 +101,7 @@ TEST(Model, RemoveConsumer) {
TEST(Model, SetSameProducer) { TEST(Model, SetSameProducer) {
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
EXPECT_FALSE(graph.SetProducer(node->id, graph_output->id).ok()); EXPECT_FALSE(graph.SetProducer(node->id, graph_output->id).ok());
} }
@ -109,7 +109,7 @@ TEST(Model, SetSameProducer) {
TEST(Model, RemoveProducer) { TEST(Model, RemoveProducer) {
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
EXPECT_THAT(graph.inputs(), UnorderedElementsAre()); EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
@ -126,8 +126,8 @@ TEST(Model, RemoveProducer) {
TEST(Model, RemoveSimpleNodeDegenerateCase) { TEST(Model, RemoveSimpleNodeDegenerateCase) {
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
@ -145,9 +145,9 @@ TEST(Model, RemoveSimpleNodeNoPreviousNode) {
GraphFloat32 graph; GraphFloat32 graph;
Node* simple_node = graph.NewNode(); Node* simple_node = graph.NewNode();
Node* consumer_node = graph.NewNode(); Node* consumer_node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* value = graph.NewValue(); Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(simple_node->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(simple_node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(simple_node->id, value->id).ok()); ASSERT_TRUE(graph.SetProducer(simple_node->id, value->id).ok());
@ -167,9 +167,9 @@ TEST(Model, RemoveSimpleNodeNoAfterNodes) {
GraphFloat32 graph; GraphFloat32 graph;
Node* simple_node = graph.NewNode(); Node* simple_node = graph.NewNode();
Node* producer_node = graph.NewNode(); Node* producer_node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* value = graph.NewValue(); Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(simple_node->id, value->id).ok()); ASSERT_TRUE(graph.AddConsumer(simple_node->id, value->id).ok());
ASSERT_TRUE(graph.SetProducer(simple_node->id, graph_output->id).ok()); ASSERT_TRUE(graph.SetProducer(simple_node->id, graph_output->id).ok());
@ -190,10 +190,10 @@ TEST(Model, RemoveSimpleNodeGeneralCase) {
Node* simple_node = graph.NewNode(); Node* simple_node = graph.NewNode();
Node* producer_node = graph.NewNode(); Node* producer_node = graph.NewNode();
Node* consumer_node = graph.NewNode(); Node* consumer_node = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* value0 = graph.NewValue(); Value<TensorRef<BHWC>>* value0 = graph.NewValue();
Value<TensorRefFloat32>* value1 = graph.NewValue(); Value<TensorRef<BHWC>>* value1 = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(producer_node->id, value0->id).ok()); ASSERT_TRUE(graph.SetProducer(producer_node->id, value0->id).ok());
@ -217,14 +217,14 @@ TEST(Model, CircularDependency) {
{ {
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* value = graph.NewValue(); Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node->id, value->id).ok()); ASSERT_TRUE(graph.AddConsumer(node->id, value->id).ok());
EXPECT_FALSE(graph.SetProducer(node->id, value->id).ok()); EXPECT_FALSE(graph.SetProducer(node->id, value->id).ok());
} }
{ {
GraphFloat32 graph; GraphFloat32 graph;
Node* node = graph.NewNode(); Node* node = graph.NewNode();
Value<TensorRefFloat32>* value = graph.NewValue(); Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.SetProducer(node->id, value->id).ok()); ASSERT_TRUE(graph.SetProducer(node->id, value->id).ok());
EXPECT_FALSE(graph.AddConsumer(node->id, value->id).ok()); EXPECT_FALSE(graph.AddConsumer(node->id, value->id).ok());
} }
@ -237,8 +237,8 @@ TEST(Model, ReassignValue) {
GraphFloat32 graph; GraphFloat32 graph;
Node* node1 = graph.NewNode(); Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode(); Node* node2 = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, graph_output->id).ok()); ASSERT_TRUE(graph.SetProducer(node1->id, graph_output->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node2->id, graph_input->id).ok());
@ -264,9 +264,9 @@ TEST(Model, DeleteValue) {
GraphFloat32 graph; GraphFloat32 graph;
Node* node1 = graph.NewNode(); Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode(); Node* node2 = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* value = graph.NewValue(); Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok()); ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok()); ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
@ -305,10 +305,10 @@ TEST(Model, DeleteNode) {
Node* node1 = graph.NewNode(); Node* node1 = graph.NewNode();
Node* node2 = graph.NewNode(); Node* node2 = graph.NewNode();
Node* node3 = graph.NewNode(); Node* node3 = graph.NewNode();
Value<TensorRefFloat32>* graph_input = graph.NewValue(); Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
Value<TensorRefFloat32>* graph_output = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
Value<TensorRefFloat32>* graph_output2 = graph.NewValue(); Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
Value<TensorRefFloat32>* value = graph.NewValue(); Value<TensorRef<BHWC>>* value = graph.NewValue();
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok()); ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok()); ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());

View File

@ -57,11 +57,11 @@ TEST(MergeConvolutionWithAddTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16); output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link1; Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16); link1->tensor.shape = BHWC(1, 4, 4, 16);
@ -108,11 +108,11 @@ TEST(MergeAddWithConvolutionTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16); output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link1; Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16); link1->tensor.shape = BHWC(1, 4, 4, 16);

View File

@ -58,11 +58,11 @@ TEST(MergeConvolutionWithMulTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16); output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link1; Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16); link1->tensor.shape = BHWC(1, 4, 4, 16);
@ -109,11 +109,11 @@ TEST(MergeMulWithConvolutionTest, Smoke) {
ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok());
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
output->tensor.shape = BHWC(1, 4, 4, 16); output->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link1; Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16); link1->tensor.shape = BHWC(1, 4, 4, 16);

View File

@ -68,16 +68,16 @@ TEST(MakeFullyConnected, Smoke) {
ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok());
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok()); ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok());
output->tensor.shape = BHWC(1, 1, 1, 32); output->tensor.shape = BHWC(1, 1, 1, 32);
Value<TensorRefFloat32>* link1; Value<TensorRef<BHWC>>* link1;
ASSERT_TRUE( ASSERT_TRUE(
ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok()); ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok());
link1->tensor.shape = BHWC(1, 4, 4, 16); link1->tensor.shape = BHWC(1, 4, 4, 16);
Value<TensorRefFloat32>* link2; Value<TensorRef<BHWC>>* link2;
ASSERT_TRUE( ASSERT_TRUE(
ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok()); ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok());
link2->tensor.shape = BHWC(1, 1, 1, 16); link2->tensor.shape = BHWC(1, 1, 1, 16);

View File

@ -38,7 +38,7 @@ TEST(MakePadding, Smoke) {
attr.axis = Axis::HEIGHT; attr.axis = Axis::HEIGHT;
concat_node->operation.attributes = attr; concat_node->operation.attributes = attr;
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok());
output->tensor.shape = BHWC(1, 7, 3, 5); output->tensor.shape = BHWC(1, 7, 3, 5);
@ -50,7 +50,7 @@ TEST(MakePadding, Smoke) {
std::vector<float>(const_attr.tensor.shape.DimensionsProduct(), 0); std::vector<float>(const_attr.tensor.shape.DimensionsProduct(), 0);
const_node->operation.attributes = const_attr; const_node->operation.attributes = const_attr;
Value<TensorRefFloat32>* const_link; Value<TensorRef<BHWC>>* const_link;
ASSERT_TRUE( ASSERT_TRUE(
ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok()); ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok());
const_link->tensor.shape = const_attr.tensor.shape; const_link->tensor.shape = const_attr.tensor.shape;

View File

@ -62,15 +62,15 @@ TEST(MatchDilatedConvolutionTest, MakesDilatedConvolution) {
ASSERT_TRUE(graph.AddConsumer(sb_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(sb_node->id, input->id).ok());
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, bs_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, bs_node, &output).ok());
output->tensor.shape = BHWC(1, 95, 1, 17); output->tensor.shape = BHWC(1, 95, 1, 17);
Value<TensorRefFloat32>* sb_link; Value<TensorRef<BHWC>>* sb_link;
ASSERT_TRUE(ConnectTwoNodes(&graph, sb_node, dw_node, &sb_link).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, sb_node, dw_node, &sb_link).ok());
sb_link->tensor.shape = BHWC(21, 128, 1, 17); sb_link->tensor.shape = BHWC(21, 128, 1, 17);
Value<TensorRefFloat32>* bs_link; Value<TensorRef<BHWC>>* bs_link;
ASSERT_TRUE(ConnectTwoNodes(&graph, dw_node, bs_node, &bs_link).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, dw_node, bs_node, &bs_link).ok());
bs_link->tensor.shape = BHWC(1, 95, 1, 17); bs_link->tensor.shape = BHWC(1, 95, 1, 17);

View File

@ -40,7 +40,7 @@ TEST(MergePaddingWith, Smoke) {
pad_node->operation.attributes = attr; pad_node->operation.attributes = attr;
auto conv_node = graph.NewNode(); auto conv_node = graph.NewNode();
Value<TensorRefFloat32>* temp; Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok());
ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok()); ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok());
conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
@ -77,7 +77,7 @@ TEST(MergePaddingWith, MergeTwo) {
pad_node1->operation.attributes = attr; pad_node1->operation.attributes = attr;
auto pad_node2 = graph.NewNode(); auto pad_node2 = graph.NewNode();
Value<TensorRefFloat32>* temp; Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok());
pad_node2->operation.type = ToString(OperationType::PAD); pad_node2->operation.type = ToString(OperationType::PAD);
attr.prepended = HWC(0, 0, 0); attr.prepended = HWC(0, 0, 0);

View File

@ -33,12 +33,12 @@ TEST(RemoveSingleInputAdd, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto add_node = graph.NewNode(); auto add_node = graph.NewNode();
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD); add_node->operation.type = ToString(OperationType::ADD);
add_node->operation.attributes = AddAttributes(); add_node->operation.attributes = AddAttributes();
Value<TensorRefFloat32>* temp; Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size()); ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size()); ASSERT_EQ(3, graph.values().size());
@ -61,14 +61,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Tensor) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto add_node = graph.NewNode(); auto add_node = graph.NewNode();
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD); add_node->operation.type = ToString(OperationType::ADD);
AddAttributes attr; AddAttributes attr;
attr.param = Tensor<Linear, DataType::FLOAT32>(); attr.param = Tensor<Linear, DataType::FLOAT32>();
add_node->operation.attributes = attr; add_node->operation.attributes = attr;
Value<TensorRefFloat32>* temp; Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
ASSERT_EQ(2, graph.nodes().size()); ASSERT_EQ(2, graph.nodes().size());
ASSERT_EQ(3, graph.values().size()); ASSERT_EQ(3, graph.values().size());
@ -90,11 +90,11 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Multiple) {
ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok());
auto add_node = graph.NewNode(); auto add_node = graph.NewNode();
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
add_node->operation.type = ToString(OperationType::ADD); add_node->operation.type = ToString(OperationType::ADD);
Value<TensorRefFloat32>* temp; Value<TensorRef<BHWC>>* temp;
ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok());
ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok());
ASSERT_EQ(3, graph.nodes().size()); ASSERT_EQ(3, graph.nodes().size());
@ -115,7 +115,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto node_to_remove = graph.NewNode(); auto node_to_remove = graph.NewNode();
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok()); ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
output->tensor.shape = BHWC(1, 5, 5, 1); output->tensor.shape = BHWC(1, 5, 5, 1);
node_to_remove->operation.type = ToString(OperationType::UPSAMPLE_2D); node_to_remove->operation.type = ToString(OperationType::UPSAMPLE_2D);
@ -124,7 +124,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
attr.type = UpsamplingType::BILINEAR; attr.type = UpsamplingType::BILINEAR;
node_to_remove->operation.attributes = attr; node_to_remove->operation.attributes = attr;
Value<TensorRefFloat32>* link; Value<TensorRef<BHWC>>* link;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
link->tensor.shape = output->tensor.shape; link->tensor.shape = output->tensor.shape;
ASSERT_EQ(2, graph.nodes().size()); ASSERT_EQ(2, graph.nodes().size());
@ -148,7 +148,7 @@ TEST(RemoveIdentityReshape, Smoke) {
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
auto node_to_remove = graph.NewNode(); auto node_to_remove = graph.NewNode();
Value<TensorRefFloat32>* output; Value<TensorRef<BHWC>>* output;
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok()); ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
output->tensor.shape = BHWC(1, 1, 1, 11); output->tensor.shape = BHWC(1, 1, 1, 11);
node_to_remove->operation.type = ToString(OperationType::RESHAPE); node_to_remove->operation.type = ToString(OperationType::RESHAPE);
@ -156,7 +156,7 @@ TEST(RemoveIdentityReshape, Smoke) {
attr.new_shape = BHWC(1, 1, 1, 11); attr.new_shape = BHWC(1, 1, 1, 11);
node_to_remove->operation.attributes = attr; node_to_remove->operation.attributes = attr;
Value<TensorRefFloat32>* link; Value<TensorRef<BHWC>>* link;
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok()); ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
link->tensor.shape = output->tensor.shape; link->tensor.shape = output->tensor.shape;
ASSERT_EQ(2, graph.nodes().size()); ASSERT_EQ(2, graph.nodes().size());

View File

@ -31,7 +31,7 @@ namespace gl {
namespace { namespace {
TEST(AddTest, TwoInputTensorsOfTheSameShape) { TEST(AddTest, TwoInputTensorsOfTheSameShape) {
TensorRefFloat32 augend, addend, output; TensorRef<BHWC> augend, addend, output;
augend.type = DataType::FLOAT32; augend.type = DataType::FLOAT32;
augend.ref = 0; augend.ref = 0;
augend.shape = BHWC(1, 2, 2, 1); augend.shape = BHWC(1, 2, 2, 1);
@ -57,7 +57,7 @@ TEST(AddTest, TwoInputTensorsOfTheSameShape) {
TEST(AddTest, InputTensorAndScalar) { TEST(AddTest, InputTensorAndScalar) {
AddAttributes attr; AddAttributes attr;
attr.param = 0.1f; attr.param = 0.1f;
TensorRefFloat32 input, output; TensorRef<BHWC> input, output;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 3, 1, 2); input.shape = BHWC(1, 3, 1, 2);
@ -75,7 +75,7 @@ TEST(AddTest, InputTensorAndScalar) {
} }
TEST(AddTest, InputTensorWithConstandBroadcast) { TEST(AddTest, InputTensorWithConstandBroadcast) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 2); input.shape = BHWC(1, 2, 2, 2);
@ -88,7 +88,7 @@ TEST(AddTest, InputTensorWithConstandBroadcast) {
tensor.data.push_back(20.0); tensor.data.push_back(20.0);
attr.param = std::move(tensor); attr.param = std::move(tensor);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 2, 2, 2); output.shape = BHWC(1, 2, 2, 2);
@ -104,19 +104,19 @@ TEST(AddTest, InputTensorWithConstandBroadcast) {
} }
TEST(AddTest, InputTensorWithRuntimeBroadcast) { TEST(AddTest, InputTensorWithRuntimeBroadcast) {
TensorRefFloat32 input1; TensorRef<BHWC> input1;
input1.type = DataType::FLOAT32; input1.type = DataType::FLOAT32;
input1.ref = 0; input1.ref = 0;
input1.shape = BHWC(1, 2, 2, 2); input1.shape = BHWC(1, 2, 2, 2);
TensorRefFloat32 input2; TensorRef<BHWC> input2;
input2.type = DataType::FLOAT32; input2.type = DataType::FLOAT32;
input2.ref = 1; input2.ref = 1;
input2.shape = BHWC(1, 1, 1, 2); input2.shape = BHWC(1, 1, 1, 2);
AddAttributes attr; AddAttributes attr;
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 2, 2, 2); output.shape = BHWC(1, 2, 2, 2);

View File

@ -31,7 +31,7 @@ namespace gl {
namespace { namespace {
TEST(ConcatTest, TwoInputTensorsByUnalignedChannel) { TEST(ConcatTest, TwoInputTensorsByUnalignedChannel) {
TensorRefFloat32 input1, input2, output; TensorRef<BHWC> input1, input2, output;
input1.type = DataType::FLOAT32; input1.type = DataType::FLOAT32;
input1.ref = 0; input1.ref = 0;
input1.shape = BHWC(1, 2, 2, 1); input1.shape = BHWC(1, 2, 2, 1);
@ -57,7 +57,7 @@ TEST(ConcatTest, TwoInputTensorsByUnalignedChannel) {
} }
TEST(ConcatTest, TwoInputTensorsByAlignedChannel) { TEST(ConcatTest, TwoInputTensorsByAlignedChannel) {
TensorRefFloat32 input1, input2, output; TensorRef<BHWC> input1, input2, output;
input1.type = DataType::FLOAT32; input1.type = DataType::FLOAT32;
input1.ref = 0; input1.ref = 0;
input1.shape = BHWC(1, 1, 1, 4); input1.shape = BHWC(1, 1, 1, 4);
@ -83,7 +83,7 @@ TEST(ConcatTest, TwoInputTensorsByAlignedChannel) {
} }
TEST(ConcatTest, TwoInputTensorsByHeight) { TEST(ConcatTest, TwoInputTensorsByHeight) {
TensorRefFloat32 input1, input2, output; TensorRef<BHWC> input1, input2, output;
input1.type = DataType::FLOAT32; input1.type = DataType::FLOAT32;
input1.ref = 0; input1.ref = 0;
input1.shape = BHWC(1, 1, 2, 1); input1.shape = BHWC(1, 1, 2, 1);
@ -109,7 +109,7 @@ TEST(ConcatTest, TwoInputTensorsByHeight) {
} }
TEST(ConcatTest, TwoInputTensorsByWidth) { TEST(ConcatTest, TwoInputTensorsByWidth) {
TensorRefFloat32 input1, input2, output; TensorRef<BHWC> input1, input2, output;
input1.type = DataType::FLOAT32; input1.type = DataType::FLOAT32;
input1.ref = 0; input1.ref = 0;
input1.shape = BHWC(1, 2, 1, 1); input1.shape = BHWC(1, 2, 1, 1);

View File

@ -31,7 +31,7 @@ namespace gl {
namespace { namespace {
TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) { TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
@ -54,7 +54,7 @@ TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
attr.padding.appended = HW(1, 0); attr.padding.appended = HW(1, 0);
attr.strides = HW(1, 1); attr.strides = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 2, 2, 2); output.shape = BHWC(1, 2, 2, 2);
@ -69,7 +69,7 @@ TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
} }
TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) { TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 3, 3, 1); input.shape = BHWC(1, 3, 3, 1);
@ -92,7 +92,7 @@ TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1); attr.strides = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 1, 1, 1); output.shape = BHWC(1, 1, 1, 1);
@ -106,7 +106,7 @@ TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
} }
TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) { TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
@ -129,7 +129,7 @@ TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1); attr.strides = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 1, 1, 1); output.shape = BHWC(1, 1, 1, 1);
@ -143,7 +143,7 @@ TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
} }
TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) { TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 1, 2); input.shape = BHWC(1, 2, 1, 2);
@ -166,7 +166,7 @@ TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1); attr.strides = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 2, 1, 2); output.shape = BHWC(1, 2, 1, 2);
@ -180,7 +180,7 @@ TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
} }
TEST(ConvTest, O1H1W1I1Stride2x2Dilation1x1) { TEST(ConvTest, O1H1W1I1Stride2x2Dilation1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 3, 3, 1); input.shape = BHWC(1, 3, 3, 1);
@ -204,7 +204,7 @@ TEST(ConvTest, O1H1W1I1Stride2x2Dilation1x1) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2); attr.strides = HW(2, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);

View File

@ -31,7 +31,7 @@ namespace gl {
namespace { namespace {
TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) { TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 1, 2); input.shape = BHWC(1, 1, 1, 2);
@ -55,7 +55,7 @@ TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1); attr.strides = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 1, 1, 4); output.shape = BHWC(1, 1, 1, 4);
@ -69,7 +69,7 @@ TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
} }
TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) { TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 3, 3, 1); input.shape = BHWC(1, 3, 3, 1);
@ -93,7 +93,7 @@ TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2); attr.strides = HW(2, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 2, 2, 2); output.shape = BHWC(1, 2, 2, 2);
@ -108,7 +108,7 @@ TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
} }
TEST(DepthwiseConvTest, O2H2W2I1Strides1x1Dilation2x2) { TEST(DepthwiseConvTest, O2H2W2I1Strides1x1Dilation2x2) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 3, 3, 1); input.shape = BHWC(1, 3, 3, 1);
@ -132,7 +132,7 @@ TEST(DepthwiseConvTest, O2H2W2I1Strides1x1Dilation2x2) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1); attr.strides = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 1, 1, 2); output.shape = BHWC(1, 1, 1, 2);

View File

@ -33,8 +33,8 @@ class ElementwiseOneArgumentTest : public ::testing::Test {
ElementwiseOneArgumentTest() = default; ElementwiseOneArgumentTest() = default;
~ElementwiseOneArgumentTest() override = default; ~ElementwiseOneArgumentTest() override = default;
TensorRefFloat32 GetTensorRef(int ref) { TensorRef<BHWC> GetTensorRef(int ref) {
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
tensor_ref.type = DataType::FLOAT32; tensor_ref.type = DataType::FLOAT32;
tensor_ref.ref = ref; tensor_ref.ref = ref;
tensor_ref.shape = BHWC(1, 2, 2, 1); tensor_ref.shape = BHWC(1, 2, 2, 1);
@ -137,8 +137,8 @@ class ElementwiseTwoArgumentsTest : public ::testing::Test {
ElementwiseTwoArgumentsTest() = default; ElementwiseTwoArgumentsTest() = default;
~ElementwiseTwoArgumentsTest() override = default; ~ElementwiseTwoArgumentsTest() override = default;
TensorRefFloat32 GetTensorRef(int ref) { TensorRef<BHWC> GetTensorRef(int ref) {
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
tensor_ref.type = DataType::FLOAT32; tensor_ref.type = DataType::FLOAT32;
tensor_ref.ref = ref; tensor_ref.ref = ref;
tensor_ref.shape = BHWC(1, 2, 2, 1); tensor_ref.shape = BHWC(1, 2, 2, 1);

View File

@ -31,7 +31,7 @@ namespace gl {
namespace { namespace {
TEST(FullyConnectedTest, MatrixByVectorMultiplication) { TEST(FullyConnectedTest, MatrixByVectorMultiplication) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 1, 2); input.shape = BHWC(1, 1, 1, 2);
@ -50,7 +50,7 @@ TEST(FullyConnectedTest, MatrixByVectorMultiplication) {
weights.data = {1, 2, 3, 4, 5, 6, 7, 8}; weights.data = {1, 2, 3, 4, 5, 6, 7, 8};
attr.weights = std::move(weights); attr.weights = std::move(weights);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 1, 1, 4); output.shape = BHWC(1, 1, 1, 4);

View File

@ -31,22 +31,22 @@ namespace gl {
namespace { namespace {
TEST(LstmTest, Input2x2x1) { TEST(LstmTest, Input2x2x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 prev_state; TensorRef<BHWC> prev_state;
prev_state.type = DataType::FLOAT32; prev_state.type = DataType::FLOAT32;
prev_state.ref = 1; prev_state.ref = 1;
prev_state.shape = BHWC(1, 2, 2, 1); prev_state.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output_state; TensorRef<BHWC> output_state;
output_state.type = DataType::FLOAT32; output_state.type = DataType::FLOAT32;
output_state.ref = 2; output_state.ref = 2;
output_state.shape = BHWC(1, 2, 2, 1); output_state.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output_activation; TensorRef<BHWC> output_activation;
output_activation.type = DataType::FLOAT32; output_activation.type = DataType::FLOAT32;
output_activation.ref = 3; output_activation.ref = 3;
output_activation.shape = BHWC(1, 2, 2, 1); output_activation.shape = BHWC(1, 2, 2, 1);

View File

@ -31,17 +31,17 @@ namespace gl {
namespace { namespace {
TEST(MaxUnpoolingTest, Kernel2x2Stride2x2) { TEST(MaxUnpoolingTest, Kernel2x2Stride2x2) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 indices; TensorRef<BHWC> indices;
indices.type = DataType::INT32; indices.type = DataType::INT32;
indices.ref = 1; indices.ref = 1;
indices.shape = BHWC(1, 2, 2, 1); indices.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 4, 4, 1); output.shape = BHWC(1, 4, 4, 1);

View File

@ -31,12 +31,12 @@ namespace gl {
namespace { namespace {
TEST(MulTest, Scalar) { TEST(MulTest, Scalar) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -51,12 +51,12 @@ TEST(MulTest, Scalar) {
} }
TEST(MulTest, Linear) { TEST(MulTest, Linear) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 2, 2); input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 1, 2, 2); output.shape = BHWC(1, 1, 2, 2);
@ -75,17 +75,17 @@ TEST(MulTest, Linear) {
} }
TEST(ApplyMaskTest, MaskChannel1) { TEST(ApplyMaskTest, MaskChannel1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 2, 2); input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 mask; TensorRef<BHWC> mask;
mask.type = DataType::FLOAT32; mask.type = DataType::FLOAT32;
mask.ref = 1; mask.ref = 1;
mask.shape = BHWC(1, 1, 2, 1); mask.shape = BHWC(1, 1, 2, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 1, 2, 2); output.shape = BHWC(1, 1, 2, 2);
@ -99,17 +99,17 @@ TEST(ApplyMaskTest, MaskChannel1) {
} }
TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) { TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 2, 2); input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 mask; TensorRef<BHWC> mask;
mask.type = DataType::FLOAT32; mask.type = DataType::FLOAT32;
mask.ref = 1; mask.ref = 1;
mask.shape = BHWC(1, 1, 2, 2); mask.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 1, 2, 2); output.shape = BHWC(1, 1, 2, 2);

View File

@ -34,12 +34,12 @@ namespace {
void TestPadOperation(const HWC& prepend, const HWC& append, void TestPadOperation(const HWC& prepend, const HWC& append,
const BHWC& output_shape, std::vector<float>&& expected) { const BHWC& output_shape, std::vector<float>&& expected) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 1, 1); input.shape = BHWC(1, 1, 1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = output_shape; output.shape = output_shape;

View File

@ -35,12 +35,12 @@ namespace gl {
namespace { namespace {
TEST(PoolingTest, MaxKernel2x2Stride2x2WithIndices) { TEST(PoolingTest, MaxKernel2x2Stride2x2WithIndices) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 4, 4, 1); input.shape = BHWC(1, 4, 4, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -70,12 +70,12 @@ TEST(PoolingTest, MaxKernel2x2Stride2x2WithIndices) {
} }
TEST(PoolingTest, MaxKernel2x2Stride2x2WithoutIndices) { TEST(PoolingTest, MaxKernel2x2Stride2x2WithoutIndices) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 4, 4, 1); input.shape = BHWC(1, 4, 4, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -96,12 +96,12 @@ TEST(PoolingTest, MaxKernel2x2Stride2x2WithoutIndices) {
} }
TEST(PoolingTest, AverageKernel2x2Stride2x2) { TEST(PoolingTest, AverageKernel2x2Stride2x2) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 4, 4, 1); input.shape = BHWC(1, 4, 4, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);

View File

@ -29,7 +29,7 @@ namespace gl {
namespace { namespace {
TEST(PReluTest, LinearAlphaNoClip) { TEST(PReluTest, LinearAlphaNoClip) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
@ -42,7 +42,7 @@ TEST(PReluTest, LinearAlphaNoClip) {
alpha.data = {2}; alpha.data = {2};
attr.alpha = std::move(alpha); attr.alpha = std::move(alpha);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -55,7 +55,7 @@ TEST(PReluTest, LinearAlphaNoClip) {
} }
TEST(PReluTest, LinearAlphaWithClip) { TEST(PReluTest, LinearAlphaWithClip) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
@ -68,7 +68,7 @@ TEST(PReluTest, LinearAlphaWithClip) {
alpha.data = {2}; alpha.data = {2};
attr.alpha = std::move(alpha); attr.alpha = std::move(alpha);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -81,7 +81,7 @@ TEST(PReluTest, LinearAlphaWithClip) {
} }
TEST(PReluTest, 3DAlphaNoClip) { TEST(PReluTest, 3DAlphaNoClip) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
@ -95,7 +95,7 @@ TEST(PReluTest, 3DAlphaNoClip) {
alpha.data = {1, 2, 2, 2}; alpha.data = {1, 2, 2, 2};
attr.alpha = std::move(alpha); attr.alpha = std::move(alpha);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -107,7 +107,7 @@ TEST(PReluTest, 3DAlphaNoClip) {
} }
TEST(PReluTest, 3DAlphaWithClip) { TEST(PReluTest, 3DAlphaWithClip) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
@ -121,7 +121,7 @@ TEST(PReluTest, 3DAlphaWithClip) {
alpha.data = {1, 2, 2, 2}; alpha.data = {1, 2, 2, 2};
attr.alpha = std::move(alpha); attr.alpha = std::move(alpha);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 2; output.ref = 2;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);

View File

@ -33,8 +33,8 @@ class ReluTest : public ::testing::Test {
ReluTest() = default; ReluTest() = default;
~ReluTest() override = default; ~ReluTest() override = default;
TensorRefFloat32 GetTensorRef(int ref) { TensorRef<BHWC> GetTensorRef(int ref) {
TensorRefFloat32 tensor_ref; TensorRef<BHWC> tensor_ref;
tensor_ref.type = DataType::FLOAT32; tensor_ref.type = DataType::FLOAT32;
tensor_ref.ref = ref; tensor_ref.ref = ref;
tensor_ref.shape = BHWC(1, 2, 2, 1); tensor_ref.shape = BHWC(1, 2, 2, 1);

View File

@ -31,12 +31,12 @@ namespace gl {
namespace { namespace {
TEST(Reshape, 1x2x3To3x2x1) { TEST(Reshape, 1x2x3To3x2x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 2, 3); input.shape = BHWC(1, 1, 2, 3);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 3, 2, 1); output.shape = BHWC(1, 3, 2, 1);
@ -53,12 +53,12 @@ TEST(Reshape, 1x2x3To3x2x1) {
} }
TEST(Reshape, 3x1x2To2x1x3) { TEST(Reshape, 3x1x2To2x1x3) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 3, 1, 2); input.shape = BHWC(1, 3, 1, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 1, 3); output.shape = BHWC(1, 2, 1, 3);
@ -75,12 +75,12 @@ TEST(Reshape, 3x1x2To2x1x3) {
} }
TEST(Reshape, 1x1x4To2x2x1) { TEST(Reshape, 1x1x4To2x2x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 1, 4); input.shape = BHWC(1, 1, 1, 4);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -96,12 +96,12 @@ TEST(Reshape, 1x1x4To2x2x1) {
} }
TEST(Reshape, BatchIsUnsupported) { TEST(Reshape, BatchIsUnsupported) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(4, 1, 1, 1); input.shape = BHWC(4, 1, 1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);

View File

@ -31,12 +31,12 @@ namespace gl {
namespace { namespace {
TEST(SliceTest, Identity) { TEST(SliceTest, Identity) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 2, 2); input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 1, 2, 2); output.shape = BHWC(1, 1, 2, 2);
@ -54,12 +54,12 @@ TEST(SliceTest, Identity) {
} }
TEST(SliceTest, NegativeEnds) { TEST(SliceTest, NegativeEnds) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 2, 2); input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 1, 2, 2); output.shape = BHWC(1, 1, 2, 2);
@ -77,12 +77,12 @@ TEST(SliceTest, NegativeEnds) {
} }
TEST(SliceTest, NegativeEndsNonZeroStarts) { TEST(SliceTest, NegativeEndsNonZeroStarts) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 2, 2); input.shape = BHWC(1, 1, 2, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 1, 1, 1); output.shape = BHWC(1, 1, 1, 1);
@ -100,12 +100,12 @@ TEST(SliceTest, NegativeEndsNonZeroStarts) {
} }
TEST(SliceTest, StridesByHeight) { TEST(SliceTest, StridesByHeight) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 4, 1, 1); input.shape = BHWC(1, 4, 1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 1, 1); output.shape = BHWC(1, 2, 1, 1);
@ -123,12 +123,12 @@ TEST(SliceTest, StridesByHeight) {
} }
TEST(SliceTest, StridesByWidth) { TEST(SliceTest, StridesByWidth) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 4, 1); input.shape = BHWC(1, 1, 4, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 1, 2, 1); output.shape = BHWC(1, 1, 2, 1);
@ -146,12 +146,12 @@ TEST(SliceTest, StridesByWidth) {
} }
TEST(SliceTest, StridesByChannels) { TEST(SliceTest, StridesByChannels) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 1, 4); input.shape = BHWC(1, 1, 1, 4);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 1, 1, 1); output.shape = BHWC(1, 1, 1, 1);

View File

@ -32,12 +32,12 @@ namespace gl {
namespace { namespace {
TEST(SoftmaxTest, WorksForChannelsAxis) { TEST(SoftmaxTest, WorksForChannelsAxis) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -53,12 +53,12 @@ TEST(SoftmaxTest, WorksForChannelsAxis) {
} }
TEST(SoftmaxTest, DoesNotWorkForHeightAxis) { TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);
@ -75,12 +75,12 @@ TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
} }
TEST(SoftmaxTest, DoesNotWorkForWidthAxis) { TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 1); output.shape = BHWC(1, 2, 2, 1);

View File

@ -37,8 +37,8 @@ namespace gpu {
namespace gl { namespace gl {
SingleOpModel::SingleOpModel(Operation&& operation, SingleOpModel::SingleOpModel(Operation&& operation,
const std::vector<TensorRefFloat32>& inputs, const std::vector<TensorRef<BHWC>>& inputs,
const std::vector<TensorRefFloat32>& outputs) { const std::vector<TensorRef<BHWC>>& outputs) {
auto node = graph_.NewNode(); auto node = graph_.NewNode();
node->operation = std::move(operation); node->operation = std::move(operation);

View File

@ -41,8 +41,8 @@ class SingleOpModel {
public: public:
SingleOpModel() = delete; SingleOpModel() = delete;
SingleOpModel(Operation&& operation, SingleOpModel(Operation&& operation,
const std::vector<TensorRefFloat32>& inputs, const std::vector<TensorRef<BHWC>>& inputs,
const std::vector<TensorRefFloat32>& outputs); const std::vector<TensorRef<BHWC>>& outputs);
virtual ~SingleOpModel() = default; virtual ~SingleOpModel() = default;

View File

@ -31,7 +31,7 @@ namespace gl {
namespace { namespace {
TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) { TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
@ -54,7 +54,7 @@ TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
attr.adjacent = HW(1, 1); attr.adjacent = HW(1, 1);
attr.stride = HW(1, 1); attr.stride = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 2, 2, 2); output.shape = BHWC(1, 2, 2, 2);
@ -69,7 +69,7 @@ TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
} }
TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) { TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 3, 3, 1); input.shape = BHWC(1, 3, 3, 1);
@ -92,7 +92,7 @@ TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.stride = HW(1, 1); attr.stride = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 1, 1, 1); output.shape = BHWC(1, 1, 1, 1);
@ -106,7 +106,7 @@ TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
} }
TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) { TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
@ -129,7 +129,7 @@ TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.stride = HW(1, 1); attr.stride = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 1, 1, 1); output.shape = BHWC(1, 1, 1, 1);
@ -143,7 +143,7 @@ TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
} }
TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) { TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 1, 2); input.shape = BHWC(1, 2, 1, 2);
@ -166,7 +166,7 @@ TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.stride = HW(1, 1); attr.stride = HW(1, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 2, 1, 2); output.shape = BHWC(1, 2, 1, 2);
@ -180,7 +180,7 @@ TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
} }
TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) { TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 3, 3, 1); input.shape = BHWC(1, 3, 3, 1);
@ -204,7 +204,7 @@ TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) {
attr.padding.appended = HW(0, 0); attr.padding.appended = HW(0, 0);
attr.stride = HW(2, 2); attr.stride = HW(2, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 3; output.ref = 3;
output.shape = BHWC(1, 1, 1, 1); output.shape = BHWC(1, 1, 1, 1);

View File

@ -31,12 +31,12 @@ namespace gl {
namespace { namespace {
TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) { TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 1, 2); input.shape = BHWC(1, 1, 1, 2);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 2, 2, 2); output.shape = BHWC(1, 2, 2, 2);
@ -56,12 +56,12 @@ TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) {
} }
TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) { TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 1, 2, 1); input.shape = BHWC(1, 1, 2, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 1, 4, 1); output.shape = BHWC(1, 1, 4, 1);
@ -80,12 +80,12 @@ TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) {
} }
TEST(UpsamplingBilinearTest, 2x2x1To4x4x1) { TEST(UpsamplingBilinearTest, 2x2x1To4x4x1) {
TensorRefFloat32 input; TensorRef<BHWC> input;
input.type = DataType::FLOAT32; input.type = DataType::FLOAT32;
input.ref = 0; input.ref = 0;
input.shape = BHWC(1, 2, 2, 1); input.shape = BHWC(1, 2, 2, 1);
TensorRefFloat32 output; TensorRef<BHWC> output;
output.type = DataType::FLOAT32; output.type = DataType::FLOAT32;
output.ref = 1; output.ref = 1;
output.shape = BHWC(1, 4, 4, 1); output.shape = BHWC(1, 4, 4, 1);

View File

@ -32,7 +32,7 @@ Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor,
return CreateReadOnlyShaderStorageBuffer<float>(transposed, gl_buffer); return CreateReadOnlyShaderStorageBuffer<float>(transposed, gl_buffer);
} }
Status CreatePHWC4BufferFromTensorRef(const TensorRefFloat32& tensor_ref, Status CreatePHWC4BufferFromTensorRef(const TensorRef<BHWC>& tensor_ref,
GlBuffer* gl_buffer) { GlBuffer* gl_buffer) {
return CreateReadWriteShaderStorageBuffer<float>( return CreateReadWriteShaderStorageBuffer<float>(
GetElementsSizeForPHWC4(tensor_ref.shape), gl_buffer); GetElementsSizeForPHWC4(tensor_ref.shape), gl_buffer);

View File

@ -72,7 +72,7 @@ Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor,
// Creates read-write buffer for the given tensor shape, where data layout is // Creates read-write buffer for the given tensor shape, where data layout is
// supposed to be PHWC4. // supposed to be PHWC4.
Status CreatePHWC4BufferFromTensorRef(const TensorRefFloat32& tensor_ref, Status CreatePHWC4BufferFromTensorRef(const TensorRef<BHWC>& tensor_ref,
GlBuffer* gl_buffer); GlBuffer* gl_buffer);
// Copies data from a buffer that holds data in PHWC4 layout to the given // Copies data from a buffer that holds data in PHWC4 layout to the given

View File

@ -148,7 +148,7 @@ class Delegate {
// TODO(impjdi): Remove code duplication. // TODO(impjdi): Remove code duplication.
auto values = graph.values(); auto values = graph.values();
auto find_value = [&](int tensor_index) -> Value<TensorRefFloat32>* { auto find_value = [&](int tensor_index) -> Value<TensorRef<BHWC>>* {
for (auto value : values) { for (auto value : values) {
if (value->tensor.ref == tensor_index) return value; if (value->tensor.ref == tensor_index) return value;
} }