Get rid of TensorRefFloat32 specialization. It no longer keeps type in template params, therefore, it is better to use ref explicitly.
PiperOrigin-RevId: 253738844
This commit is contained in:
parent
3557034977
commit
5044a18831
@ -68,12 +68,12 @@ int64_t DimensionsProduct(const TfLiteIntArray& dims) {
|
||||
// will turn into:
|
||||
// node(copy(output)) <- passthrough_node(output)
|
||||
Status NewPassthroughNode(GraphFloat32* graph, Node* node,
|
||||
const Value<TensorRefFloat32>* output,
|
||||
const Value<TensorRef<BHWC>>* output,
|
||||
Node** passthru_node) {
|
||||
*passthru_node = graph->NewNode();
|
||||
// Make copies for every output in the original node.
|
||||
RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
|
||||
Value<TensorRefFloat32>* copy_output = graph->NewValue();
|
||||
Value<TensorRef<BHWC>>* copy_output = graph->NewValue();
|
||||
RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
|
||||
RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
|
||||
copy_output->tensor = output->tensor;
|
||||
@ -265,13 +265,13 @@ class ObjectReader {
|
||||
public:
|
||||
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
std::vector<Value<TensorRefFloat32>*>* tensor_to_value)
|
||||
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value)
|
||||
: graph_(graph),
|
||||
context_(context),
|
||||
tflite_node_(tflite_node),
|
||||
tensor_to_value_(tensor_to_value) {}
|
||||
|
||||
Status ReadValue(uint32_t idx, Value<TensorRefFloat32>** value) {
|
||||
Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) {
|
||||
if (idx >= tflite_node_->inputs->size) {
|
||||
return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx));
|
||||
}
|
||||
@ -319,7 +319,7 @@ class ObjectReader {
|
||||
tflite_node_->outputs->size));
|
||||
}
|
||||
int output_tensor_idx = tflite_node_->outputs->data[id];
|
||||
Value<TensorRefFloat32>* value;
|
||||
Value<TensorRef<BHWC>>* value;
|
||||
RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
|
||||
RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
|
||||
return OkStatus();
|
||||
@ -333,13 +333,13 @@ class ObjectReader {
|
||||
}
|
||||
|
||||
Status AddInput(const Node* node, uint32_t idx) {
|
||||
Value<TensorRefFloat32>* input;
|
||||
Value<TensorRef<BHWC>>* input;
|
||||
RETURN_IF_ERROR(ReadValue(idx, &input));
|
||||
return graph_->AddConsumer(node->id, input->id);
|
||||
}
|
||||
|
||||
Status ReadValueByTensorIdx(uint32_t tensor_idx,
|
||||
Value<TensorRefFloat32>** value) {
|
||||
Value<TensorRef<BHWC>>** value) {
|
||||
if (tensor_idx >= tensor_to_value_->size()) {
|
||||
return OutOfRangeError(
|
||||
StrCat("ReadValue: input tensor index: ", tensor_idx));
|
||||
@ -350,7 +350,7 @@ class ObjectReader {
|
||||
return NotFoundError(
|
||||
StrCat("ReadValue: value is a constant tensor: ", tensor_idx));
|
||||
}
|
||||
Value<TensorRefFloat32>* value = graph_->NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph_->NewValue();
|
||||
RETURN_IF_ERROR(
|
||||
ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
|
||||
value->tensor.ref = tensor_idx;
|
||||
@ -364,7 +364,7 @@ class ObjectReader {
|
||||
GraphFloat32* graph_ = nullptr;
|
||||
const TfLiteContext* context_ = nullptr;
|
||||
const TfLiteNode* tflite_node_ = nullptr;
|
||||
std::vector<Value<TensorRefFloat32>*>* tensor_to_value_;
|
||||
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value_;
|
||||
};
|
||||
|
||||
Status CheckInputsOutputs(const TfLiteContext* context,
|
||||
@ -639,7 +639,7 @@ class Conv2DOperationParser : public TFLiteOperationParser {
|
||||
|
||||
// Creates a simple node that holds tensor value.
|
||||
Status NewConstNode(TensorFloat32 t, GraphFloat32* graph,
|
||||
Value<TensorRefFloat32>** value) {
|
||||
Value<TensorRef<BHWC>>** value) {
|
||||
ConstTensorAttributes attr;
|
||||
attr.tensor = std::move(t);
|
||||
Node* node = graph->NewNode();
|
||||
@ -677,16 +677,16 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
|
||||
ConcatAttributes attr;
|
||||
// Read inputs first to make sure const node is added to a graph before
|
||||
// concat node to ensure topological order.
|
||||
std::vector<const Value<TensorRefFloat32>*> inputs;
|
||||
std::vector<const Value<TensorRef<BHWC>>*> inputs;
|
||||
for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
|
||||
Value<TensorRefFloat32>* value;
|
||||
Value<TensorRef<BHWC>>* value;
|
||||
const auto status = reader->ReadValue(idx, &value);
|
||||
if (status.ok()) {
|
||||
inputs.push_back(value);
|
||||
} else {
|
||||
TensorFloat32 tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
|
||||
Value<TensorRefFloat32>* value;
|
||||
Value<TensorRef<BHWC>>* value;
|
||||
RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
|
||||
inputs.push_back(value);
|
||||
}
|
||||
@ -695,7 +695,7 @@ class ConcatenationOperationParser : public TFLiteOperationParser {
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::CONCAT);
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
for (const Value<TensorRefFloat32>* input : inputs) {
|
||||
for (const Value<TensorRef<BHWC>>* input : inputs) {
|
||||
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
|
||||
}
|
||||
|
||||
@ -1143,11 +1143,11 @@ class LstmOperationParser : public TFLiteOperationParser {
|
||||
lstm_attr.kernel_type = LstmKernelType::BASIC;
|
||||
lstm_node->operation.attributes = lstm_attr;
|
||||
|
||||
Value<TensorRefFloat32>* concat_temp;
|
||||
Value<TensorRef<BHWC>>* concat_temp;
|
||||
int concat_tensor_idx = tflite_node->outputs->data[2];
|
||||
RETURN_IF_ERROR(
|
||||
reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
|
||||
Value<TensorRefFloat32>* activ_temp;
|
||||
Value<TensorRef<BHWC>>* activ_temp;
|
||||
int activ_tensor_idx = tflite_node->outputs->data[3];
|
||||
RETURN_IF_ERROR(
|
||||
reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
|
||||
@ -1521,7 +1521,7 @@ class FullyConnectedOperationParser : public TFLiteOperationParser {
|
||||
if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
|
||||
auto& reshape = node;
|
||||
conv = graph->NewNode(); // reset conv pointer!
|
||||
Value<TensorRefFloat32>* reshaped_value = graph->NewValue();
|
||||
Value<TensorRef<BHWC>>* reshaped_value = graph->NewValue();
|
||||
reshaped_value->tensor.shape = BHWC(1, 1, 1, weights.shape.w);
|
||||
RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
|
||||
reshape->operation.type = ToString(OperationType::RESHAPE);
|
||||
@ -1558,7 +1558,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::SLICE);
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
Value<TensorRefFloat32>* input;
|
||||
Value<TensorRef<BHWC>>* input;
|
||||
RETURN_IF_ERROR(reader->ReadValue(0, &input));
|
||||
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
|
||||
|
||||
@ -1721,7 +1721,7 @@ class TransposeConvOperationParser : public TFLiteOperationParser {
|
||||
ObjectReader* reader) final {
|
||||
auto* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
|
||||
Value<TensorRefFloat32>* input;
|
||||
Value<TensorRef<BHWC>>* input;
|
||||
RETURN_IF_ERROR(reader->ReadValue(2, &input));
|
||||
RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
@ -1970,7 +1970,7 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
} // namespace
|
||||
|
||||
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
||||
TensorRefFloat32* tensor_ref) {
|
||||
TensorRef<BHWC>* tensor_ref) {
|
||||
tensor_ref->type = ToDataType(tflite_tensor.type);
|
||||
const TfLiteIntArray* dims = tflite_tensor.dims;
|
||||
switch (dims->size) {
|
||||
@ -2128,8 +2128,8 @@ Status BuildModel(TfLiteContext* context,
|
||||
}
|
||||
operations.push_back(std::move(op_parser));
|
||||
}
|
||||
std::vector<Value<TensorRefFloat32>*> tensor_to_value(context->tensors_size,
|
||||
nullptr);
|
||||
std::vector<Value<TensorRef<BHWC>>*> tensor_to_value(context->tensors_size,
|
||||
nullptr);
|
||||
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
|
||||
TfLiteNode* tflite_node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
|
@ -38,7 +38,7 @@ Status BuildModel(TfLiteContext* context,
|
||||
|
||||
// Module-internal converter, exposed for unit testing purpose only.
|
||||
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
||||
TensorRefFloat32* tensor_ref);
|
||||
TensorRef<BHWC>* tensor_ref);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -34,7 +34,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank0) {
|
||||
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
|
||||
tflite_tensor.dims = TfLiteIntArrayCreate(1);
|
||||
tflite_tensor.dims->data[0] = 4;
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
const auto status =
|
||||
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
|
||||
TfLiteIntArrayFree(tflite_tensor.dims);
|
||||
@ -49,7 +49,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank1) {
|
||||
tflite_tensor.dims = TfLiteIntArrayCreate(2);
|
||||
tflite_tensor.dims->data[0] = 4;
|
||||
tflite_tensor.dims->data[1] = 5;
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
const auto status =
|
||||
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
|
||||
TfLiteIntArrayFree(tflite_tensor.dims);
|
||||
@ -65,7 +65,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank2) {
|
||||
tflite_tensor.dims->data[0] = 4;
|
||||
tflite_tensor.dims->data[1] = 5;
|
||||
tflite_tensor.dims->data[2] = 6;
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
const auto status =
|
||||
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
|
||||
TfLiteIntArrayFree(tflite_tensor.dims);
|
||||
@ -82,7 +82,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank3) {
|
||||
tflite_tensor.dims->data[1] = 5;
|
||||
tflite_tensor.dims->data[2] = 6;
|
||||
tflite_tensor.dims->data[3] = 7;
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
const auto status =
|
||||
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
|
||||
TfLiteIntArrayFree(tflite_tensor.dims);
|
||||
@ -95,7 +95,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankLT0) {
|
||||
TfLiteTensor tflite_tensor;
|
||||
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
|
||||
tflite_tensor.dims = TfLiteIntArrayCreate(0);
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
const auto status =
|
||||
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
|
||||
TfLiteIntArrayFree(tflite_tensor.dims);
|
||||
@ -107,7 +107,7 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) {
|
||||
TfLiteTensor tflite_tensor;
|
||||
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
|
||||
tflite_tensor.dims = TfLiteIntArrayCreate(5);
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
const auto status =
|
||||
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
|
||||
TfLiteIntArrayFree(tflite_tensor.dims);
|
||||
|
@ -31,8 +31,8 @@ TEST(Model, SingleNode) {
|
||||
// graph_input -> node -> graph_output
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||
|
||||
@ -52,9 +52,9 @@ TEST(Model, SingleNodeMultipleOutputs) {
|
||||
// graph_input -> node -> (graph_output1, graph_output2)
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output1 = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output2 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output1 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, graph_output1->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, graph_output2->id).ok());
|
||||
@ -67,7 +67,7 @@ TEST(Model, SingleNodeMultipleOutputs) {
|
||||
TEST(Model, SetSameConsumer) {
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||
EXPECT_FALSE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||
}
|
||||
@ -76,8 +76,8 @@ TEST(Model, RemoveConsumer) {
|
||||
// (graph_input1, graph_input2) -> node
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input1 = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_input2 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input1 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input2 = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input1->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input2->id).ok());
|
||||
EXPECT_THAT(graph.FindConsumers(graph_input1->id),
|
||||
@ -101,7 +101,7 @@ TEST(Model, RemoveConsumer) {
|
||||
TEST(Model, SetSameProducer) {
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||
EXPECT_FALSE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||
}
|
||||
@ -109,7 +109,7 @@ TEST(Model, SetSameProducer) {
|
||||
TEST(Model, RemoveProducer) {
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||
EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
|
||||
@ -126,8 +126,8 @@ TEST(Model, RemoveProducer) {
|
||||
TEST(Model, RemoveSimpleNodeDegenerateCase) {
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok());
|
||||
@ -145,9 +145,9 @@ TEST(Model, RemoveSimpleNodeNoPreviousNode) {
|
||||
GraphFloat32 graph;
|
||||
Node* simple_node = graph.NewNode();
|
||||
Node* consumer_node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(simple_node->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(simple_node->id, value->id).ok());
|
||||
@ -167,9 +167,9 @@ TEST(Model, RemoveSimpleNodeNoAfterNodes) {
|
||||
GraphFloat32 graph;
|
||||
Node* simple_node = graph.NewNode();
|
||||
Node* producer_node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(simple_node->id, value->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(simple_node->id, graph_output->id).ok());
|
||||
@ -190,10 +190,10 @@ TEST(Model, RemoveSimpleNodeGeneralCase) {
|
||||
Node* simple_node = graph.NewNode();
|
||||
Node* producer_node = graph.NewNode();
|
||||
Node* consumer_node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRefFloat32>* value0 = graph.NewValue();
|
||||
Value<TensorRefFloat32>* value1 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value0 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value1 = graph.NewValue();
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(producer_node->id, value0->id).ok());
|
||||
@ -217,14 +217,14 @@ TEST(Model, CircularDependency) {
|
||||
{
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node->id, value->id).ok());
|
||||
EXPECT_FALSE(graph.SetProducer(node->id, value->id).ok());
|
||||
}
|
||||
{
|
||||
GraphFloat32 graph;
|
||||
Node* node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.SetProducer(node->id, value->id).ok());
|
||||
EXPECT_FALSE(graph.AddConsumer(node->id, value->id).ok());
|
||||
}
|
||||
@ -237,8 +237,8 @@ TEST(Model, ReassignValue) {
|
||||
GraphFloat32 graph;
|
||||
Node* node1 = graph.NewNode();
|
||||
Node* node2 = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node1->id, graph_output->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(node2->id, graph_input->id).ok());
|
||||
@ -264,9 +264,9 @@ TEST(Model, DeleteValue) {
|
||||
GraphFloat32 graph;
|
||||
Node* node1 = graph.NewNode();
|
||||
Node* node2 = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
|
||||
@ -305,10 +305,10 @@ TEST(Model, DeleteNode) {
|
||||
Node* node1 = graph.NewNode();
|
||||
Node* node2 = graph.NewNode();
|
||||
Node* node3 = graph.NewNode();
|
||||
Value<TensorRefFloat32>* graph_input = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output = graph.NewValue();
|
||||
Value<TensorRefFloat32>* graph_output2 = graph.NewValue();
|
||||
Value<TensorRefFloat32>* value = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_input = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* graph_output2 = graph.NewValue();
|
||||
Value<TensorRef<BHWC>>* value = graph.NewValue();
|
||||
ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
|
||||
ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
|
||||
ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
|
||||
|
@ -57,11 +57,11 @@ TEST(MergeConvolutionWithAddTest, Smoke) {
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());
|
||||
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
Value<TensorRefFloat32>* link1;
|
||||
Value<TensorRef<BHWC>>* link1;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok());
|
||||
link1->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
@ -108,11 +108,11 @@ TEST(MergeAddWithConvolutionTest, Smoke) {
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());
|
||||
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
Value<TensorRefFloat32>* link1;
|
||||
Value<TensorRef<BHWC>>* link1;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok());
|
||||
link1->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
|
@ -58,11 +58,11 @@ TEST(MergeConvolutionWithMulTest, Smoke) {
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok());
|
||||
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
Value<TensorRefFloat32>* link1;
|
||||
Value<TensorRef<BHWC>>* link1;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok());
|
||||
link1->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
@ -109,11 +109,11 @@ TEST(MergeMulWithConvolutionTest, Smoke) {
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok());
|
||||
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
Value<TensorRefFloat32>* link1;
|
||||
Value<TensorRef<BHWC>>* link1;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok());
|
||||
link1->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
|
@ -68,16 +68,16 @@ TEST(MakeFullyConnected, Smoke) {
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok());
|
||||
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 1, 1, 32);
|
||||
|
||||
Value<TensorRefFloat32>* link1;
|
||||
Value<TensorRef<BHWC>>* link1;
|
||||
ASSERT_TRUE(
|
||||
ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok());
|
||||
link1->tensor.shape = BHWC(1, 4, 4, 16);
|
||||
|
||||
Value<TensorRefFloat32>* link2;
|
||||
Value<TensorRef<BHWC>>* link2;
|
||||
ASSERT_TRUE(
|
||||
ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok());
|
||||
link2->tensor.shape = BHWC(1, 1, 1, 16);
|
||||
|
@ -38,7 +38,7 @@ TEST(MakePadding, Smoke) {
|
||||
attr.axis = Axis::HEIGHT;
|
||||
concat_node->operation.attributes = attr;
|
||||
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 7, 3, 5);
|
||||
|
||||
@ -50,7 +50,7 @@ TEST(MakePadding, Smoke) {
|
||||
std::vector<float>(const_attr.tensor.shape.DimensionsProduct(), 0);
|
||||
const_node->operation.attributes = const_attr;
|
||||
|
||||
Value<TensorRefFloat32>* const_link;
|
||||
Value<TensorRef<BHWC>>* const_link;
|
||||
ASSERT_TRUE(
|
||||
ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok());
|
||||
const_link->tensor.shape = const_attr.tensor.shape;
|
||||
|
@ -62,15 +62,15 @@ TEST(MatchDilatedConvolutionTest, MakesDilatedConvolution) {
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(sb_node->id, input->id).ok());
|
||||
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, bs_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 95, 1, 17);
|
||||
|
||||
Value<TensorRefFloat32>* sb_link;
|
||||
Value<TensorRef<BHWC>>* sb_link;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, sb_node, dw_node, &sb_link).ok());
|
||||
sb_link->tensor.shape = BHWC(21, 128, 1, 17);
|
||||
|
||||
Value<TensorRefFloat32>* bs_link;
|
||||
Value<TensorRef<BHWC>>* bs_link;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, dw_node, bs_node, &bs_link).ok());
|
||||
bs_link->tensor.shape = BHWC(1, 95, 1, 17);
|
||||
|
||||
|
@ -40,7 +40,7 @@ TEST(MergePaddingWith, Smoke) {
|
||||
pad_node->operation.attributes = attr;
|
||||
|
||||
auto conv_node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* temp;
|
||||
Value<TensorRef<BHWC>>* temp;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok());
|
||||
ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok());
|
||||
conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D);
|
||||
@ -77,7 +77,7 @@ TEST(MergePaddingWith, MergeTwo) {
|
||||
pad_node1->operation.attributes = attr;
|
||||
|
||||
auto pad_node2 = graph.NewNode();
|
||||
Value<TensorRefFloat32>* temp;
|
||||
Value<TensorRef<BHWC>>* temp;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok());
|
||||
pad_node2->operation.type = ToString(OperationType::PAD);
|
||||
attr.prepended = HWC(0, 0, 0);
|
||||
|
@ -33,12 +33,12 @@ TEST(RemoveSingleInputAdd, Smoke) {
|
||||
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
|
||||
|
||||
auto add_node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
|
||||
add_node->operation.type = ToString(OperationType::ADD);
|
||||
add_node->operation.attributes = AddAttributes();
|
||||
|
||||
Value<TensorRefFloat32>* temp;
|
||||
Value<TensorRef<BHWC>>* temp;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
|
||||
ASSERT_EQ(2, graph.nodes().size());
|
||||
ASSERT_EQ(3, graph.values().size());
|
||||
@ -61,14 +61,14 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Tensor) {
|
||||
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
|
||||
|
||||
auto add_node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
|
||||
add_node->operation.type = ToString(OperationType::ADD);
|
||||
AddAttributes attr;
|
||||
attr.param = Tensor<Linear, DataType::FLOAT32>();
|
||||
add_node->operation.attributes = attr;
|
||||
|
||||
Value<TensorRefFloat32>* temp;
|
||||
Value<TensorRef<BHWC>>* temp;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok());
|
||||
ASSERT_EQ(2, graph.nodes().size());
|
||||
ASSERT_EQ(3, graph.values().size());
|
||||
@ -90,11 +90,11 @@ TEST(RemoveSingleInputAdd, DoNotTrigger_Multiple) {
|
||||
ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok());
|
||||
|
||||
auto add_node = graph.NewNode();
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
|
||||
add_node->operation.type = ToString(OperationType::ADD);
|
||||
|
||||
Value<TensorRefFloat32>* temp;
|
||||
Value<TensorRef<BHWC>>* temp;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok());
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok());
|
||||
ASSERT_EQ(3, graph.nodes().size());
|
||||
@ -115,7 +115,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
|
||||
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
|
||||
|
||||
auto node_to_remove = graph.NewNode();
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 5, 5, 1);
|
||||
node_to_remove->operation.type = ToString(OperationType::UPSAMPLE_2D);
|
||||
@ -124,7 +124,7 @@ TEST(RemoveDegenerateUpsampling, Smoke) {
|
||||
attr.type = UpsamplingType::BILINEAR;
|
||||
node_to_remove->operation.attributes = attr;
|
||||
|
||||
Value<TensorRefFloat32>* link;
|
||||
Value<TensorRef<BHWC>>* link;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
|
||||
link->tensor.shape = output->tensor.shape;
|
||||
ASSERT_EQ(2, graph.nodes().size());
|
||||
@ -148,7 +148,7 @@ TEST(RemoveIdentityReshape, Smoke) {
|
||||
ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok());
|
||||
|
||||
auto node_to_remove = graph.NewNode();
|
||||
Value<TensorRefFloat32>* output;
|
||||
Value<TensorRef<BHWC>>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 1, 1, 11);
|
||||
node_to_remove->operation.type = ToString(OperationType::RESHAPE);
|
||||
@ -156,7 +156,7 @@ TEST(RemoveIdentityReshape, Smoke) {
|
||||
attr.new_shape = BHWC(1, 1, 1, 11);
|
||||
node_to_remove->operation.attributes = attr;
|
||||
|
||||
Value<TensorRefFloat32>* link;
|
||||
Value<TensorRef<BHWC>>* link;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok());
|
||||
link->tensor.shape = output->tensor.shape;
|
||||
ASSERT_EQ(2, graph.nodes().size());
|
||||
|
@ -31,7 +31,7 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(AddTest, TwoInputTensorsOfTheSameShape) {
|
||||
TensorRefFloat32 augend, addend, output;
|
||||
TensorRef<BHWC> augend, addend, output;
|
||||
augend.type = DataType::FLOAT32;
|
||||
augend.ref = 0;
|
||||
augend.shape = BHWC(1, 2, 2, 1);
|
||||
@ -57,7 +57,7 @@ TEST(AddTest, TwoInputTensorsOfTheSameShape) {
|
||||
TEST(AddTest, InputTensorAndScalar) {
|
||||
AddAttributes attr;
|
||||
attr.param = 0.1f;
|
||||
TensorRefFloat32 input, output;
|
||||
TensorRef<BHWC> input, output;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 3, 1, 2);
|
||||
@ -75,7 +75,7 @@ TEST(AddTest, InputTensorAndScalar) {
|
||||
}
|
||||
|
||||
TEST(AddTest, InputTensorWithConstandBroadcast) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 2);
|
||||
@ -88,7 +88,7 @@ TEST(AddTest, InputTensorWithConstandBroadcast) {
|
||||
tensor.data.push_back(20.0);
|
||||
attr.param = std::move(tensor);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 2, 2, 2);
|
||||
@ -104,19 +104,19 @@ TEST(AddTest, InputTensorWithConstandBroadcast) {
|
||||
}
|
||||
|
||||
TEST(AddTest, InputTensorWithRuntimeBroadcast) {
|
||||
TensorRefFloat32 input1;
|
||||
TensorRef<BHWC> input1;
|
||||
input1.type = DataType::FLOAT32;
|
||||
input1.ref = 0;
|
||||
input1.shape = BHWC(1, 2, 2, 2);
|
||||
|
||||
TensorRefFloat32 input2;
|
||||
TensorRef<BHWC> input2;
|
||||
input2.type = DataType::FLOAT32;
|
||||
input2.ref = 1;
|
||||
input2.shape = BHWC(1, 1, 1, 2);
|
||||
|
||||
AddAttributes attr;
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 2, 2, 2);
|
||||
|
@ -31,7 +31,7 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(ConcatTest, TwoInputTensorsByUnalignedChannel) {
|
||||
TensorRefFloat32 input1, input2, output;
|
||||
TensorRef<BHWC> input1, input2, output;
|
||||
input1.type = DataType::FLOAT32;
|
||||
input1.ref = 0;
|
||||
input1.shape = BHWC(1, 2, 2, 1);
|
||||
@ -57,7 +57,7 @@ TEST(ConcatTest, TwoInputTensorsByUnalignedChannel) {
|
||||
}
|
||||
|
||||
TEST(ConcatTest, TwoInputTensorsByAlignedChannel) {
|
||||
TensorRefFloat32 input1, input2, output;
|
||||
TensorRef<BHWC> input1, input2, output;
|
||||
input1.type = DataType::FLOAT32;
|
||||
input1.ref = 0;
|
||||
input1.shape = BHWC(1, 1, 1, 4);
|
||||
@ -83,7 +83,7 @@ TEST(ConcatTest, TwoInputTensorsByAlignedChannel) {
|
||||
}
|
||||
|
||||
TEST(ConcatTest, TwoInputTensorsByHeight) {
|
||||
TensorRefFloat32 input1, input2, output;
|
||||
TensorRef<BHWC> input1, input2, output;
|
||||
input1.type = DataType::FLOAT32;
|
||||
input1.ref = 0;
|
||||
input1.shape = BHWC(1, 1, 2, 1);
|
||||
@ -109,7 +109,7 @@ TEST(ConcatTest, TwoInputTensorsByHeight) {
|
||||
}
|
||||
|
||||
TEST(ConcatTest, TwoInputTensorsByWidth) {
|
||||
TensorRefFloat32 input1, input2, output;
|
||||
TensorRef<BHWC> input1, input2, output;
|
||||
input1.type = DataType::FLOAT32;
|
||||
input1.ref = 0;
|
||||
input1.shape = BHWC(1, 2, 1, 1);
|
||||
|
@ -31,7 +31,7 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
@ -54,7 +54,7 @@ TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
|
||||
attr.padding.appended = HW(1, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 2, 2, 2);
|
||||
@ -69,7 +69,7 @@ TEST(ConvTest, O2H2W1I1Stride1x1Dilation1x1) {
|
||||
}
|
||||
|
||||
TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 3, 3, 1);
|
||||
@ -92,7 +92,7 @@ TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
@ -106,7 +106,7 @@ TEST(ConvTest, O1H2W2I1Stride1x1Dilation2x2) {
|
||||
}
|
||||
|
||||
TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
@ -129,7 +129,7 @@ TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
@ -143,7 +143,7 @@ TEST(ConvTest, O1H3W3I1Stride1x1Dilation1x1) {
|
||||
}
|
||||
|
||||
TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 1, 2);
|
||||
@ -166,7 +166,7 @@ TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 2, 1, 2);
|
||||
@ -180,7 +180,7 @@ TEST(ConvTest, O2H1W1I2Stride1x1Dilation1x1) {
|
||||
}
|
||||
|
||||
TEST(ConvTest, O1H1W1I1Stride2x2Dilation1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 3, 3, 1);
|
||||
@ -204,7 +204,7 @@ TEST(ConvTest, O1H1W1I1Stride2x2Dilation1x1) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(2, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
|
@ -31,7 +31,7 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 1, 2);
|
||||
@ -55,7 +55,7 @@ TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 4);
|
||||
@ -69,7 +69,7 @@ TEST(DepthwiseConvTest, O4H1W1I2Strides1x1Dilation1x1) {
|
||||
}
|
||||
|
||||
TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 3, 3, 1);
|
||||
@ -93,7 +93,7 @@ TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(2, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 2, 2, 2);
|
||||
@ -108,7 +108,7 @@ TEST(DepthwiseConvTest, O2H1W1I1Strides2x2Dilation1x1) {
|
||||
}
|
||||
|
||||
TEST(DepthwiseConvTest, O2H2W2I1Strides1x1Dilation2x2) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 3, 3, 1);
|
||||
@ -132,7 +132,7 @@ TEST(DepthwiseConvTest, O2H2W2I1Strides1x1Dilation2x2) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 2);
|
||||
|
@ -33,8 +33,8 @@ class ElementwiseOneArgumentTest : public ::testing::Test {
|
||||
ElementwiseOneArgumentTest() = default;
|
||||
~ElementwiseOneArgumentTest() override = default;
|
||||
|
||||
TensorRefFloat32 GetTensorRef(int ref) {
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> GetTensorRef(int ref) {
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
tensor_ref.type = DataType::FLOAT32;
|
||||
tensor_ref.ref = ref;
|
||||
tensor_ref.shape = BHWC(1, 2, 2, 1);
|
||||
@ -137,8 +137,8 @@ class ElementwiseTwoArgumentsTest : public ::testing::Test {
|
||||
ElementwiseTwoArgumentsTest() = default;
|
||||
~ElementwiseTwoArgumentsTest() override = default;
|
||||
|
||||
TensorRefFloat32 GetTensorRef(int ref) {
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> GetTensorRef(int ref) {
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
tensor_ref.type = DataType::FLOAT32;
|
||||
tensor_ref.ref = ref;
|
||||
tensor_ref.shape = BHWC(1, 2, 2, 1);
|
||||
|
@ -31,7 +31,7 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(FullyConnectedTest, MatrixByVectorMultiplication) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 1, 2);
|
||||
@ -50,7 +50,7 @@ TEST(FullyConnectedTest, MatrixByVectorMultiplication) {
|
||||
weights.data = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
attr.weights = std::move(weights);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 1, 1, 4);
|
||||
|
@ -31,22 +31,22 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(LstmTest, Input2x2x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 prev_state;
|
||||
TensorRef<BHWC> prev_state;
|
||||
prev_state.type = DataType::FLOAT32;
|
||||
prev_state.ref = 1;
|
||||
prev_state.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 output_state;
|
||||
TensorRef<BHWC> output_state;
|
||||
output_state.type = DataType::FLOAT32;
|
||||
output_state.ref = 2;
|
||||
output_state.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 output_activation;
|
||||
TensorRef<BHWC> output_activation;
|
||||
output_activation.type = DataType::FLOAT32;
|
||||
output_activation.ref = 3;
|
||||
output_activation.shape = BHWC(1, 2, 2, 1);
|
||||
|
@ -31,17 +31,17 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(MaxUnpoolingTest, Kernel2x2Stride2x2) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 indices;
|
||||
TensorRef<BHWC> indices;
|
||||
indices.type = DataType::INT32;
|
||||
indices.ref = 1;
|
||||
indices.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 4, 4, 1);
|
||||
|
@ -31,12 +31,12 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(MulTest, Scalar) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -51,12 +51,12 @@ TEST(MulTest, Scalar) {
|
||||
}
|
||||
|
||||
TEST(MulTest, Linear) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 2, 2);
|
||||
@ -75,17 +75,17 @@ TEST(MulTest, Linear) {
|
||||
}
|
||||
|
||||
TEST(ApplyMaskTest, MaskChannel1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
TensorRefFloat32 mask;
|
||||
TensorRef<BHWC> mask;
|
||||
mask.type = DataType::FLOAT32;
|
||||
mask.ref = 1;
|
||||
mask.shape = BHWC(1, 1, 2, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 1, 2, 2);
|
||||
@ -99,17 +99,17 @@ TEST(ApplyMaskTest, MaskChannel1) {
|
||||
}
|
||||
|
||||
TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
TensorRefFloat32 mask;
|
||||
TensorRef<BHWC> mask;
|
||||
mask.type = DataType::FLOAT32;
|
||||
mask.ref = 1;
|
||||
mask.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 1, 2, 2);
|
||||
|
@ -34,12 +34,12 @@ namespace {
|
||||
|
||||
void TestPadOperation(const HWC& prepend, const HWC& append,
|
||||
const BHWC& output_shape, std::vector<float>&& expected) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = output_shape;
|
||||
|
@ -35,12 +35,12 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(PoolingTest, MaxKernel2x2Stride2x2WithIndices) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 4, 4, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -70,12 +70,12 @@ TEST(PoolingTest, MaxKernel2x2Stride2x2WithIndices) {
|
||||
}
|
||||
|
||||
TEST(PoolingTest, MaxKernel2x2Stride2x2WithoutIndices) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 4, 4, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -96,12 +96,12 @@ TEST(PoolingTest, MaxKernel2x2Stride2x2WithoutIndices) {
|
||||
}
|
||||
|
||||
TEST(PoolingTest, AverageKernel2x2Stride2x2) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 4, 4, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
|
@ -29,7 +29,7 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(PReluTest, LinearAlphaNoClip) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
@ -42,7 +42,7 @@ TEST(PReluTest, LinearAlphaNoClip) {
|
||||
alpha.data = {2};
|
||||
attr.alpha = std::move(alpha);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -55,7 +55,7 @@ TEST(PReluTest, LinearAlphaNoClip) {
|
||||
}
|
||||
|
||||
TEST(PReluTest, LinearAlphaWithClip) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
@ -68,7 +68,7 @@ TEST(PReluTest, LinearAlphaWithClip) {
|
||||
alpha.data = {2};
|
||||
attr.alpha = std::move(alpha);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -81,7 +81,7 @@ TEST(PReluTest, LinearAlphaWithClip) {
|
||||
}
|
||||
|
||||
TEST(PReluTest, 3DAlphaNoClip) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
@ -95,7 +95,7 @@ TEST(PReluTest, 3DAlphaNoClip) {
|
||||
alpha.data = {1, 2, 2, 2};
|
||||
attr.alpha = std::move(alpha);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -107,7 +107,7 @@ TEST(PReluTest, 3DAlphaNoClip) {
|
||||
}
|
||||
|
||||
TEST(PReluTest, 3DAlphaWithClip) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
@ -121,7 +121,7 @@ TEST(PReluTest, 3DAlphaWithClip) {
|
||||
alpha.data = {1, 2, 2, 2};
|
||||
attr.alpha = std::move(alpha);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 2;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
|
@ -33,8 +33,8 @@ class ReluTest : public ::testing::Test {
|
||||
ReluTest() = default;
|
||||
~ReluTest() override = default;
|
||||
|
||||
TensorRefFloat32 GetTensorRef(int ref) {
|
||||
TensorRefFloat32 tensor_ref;
|
||||
TensorRef<BHWC> GetTensorRef(int ref) {
|
||||
TensorRef<BHWC> tensor_ref;
|
||||
tensor_ref.type = DataType::FLOAT32;
|
||||
tensor_ref.ref = ref;
|
||||
tensor_ref.shape = BHWC(1, 2, 2, 1);
|
||||
|
@ -31,12 +31,12 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(Reshape, 1x2x3To3x2x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 3);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 3, 2, 1);
|
||||
@ -53,12 +53,12 @@ TEST(Reshape, 1x2x3To3x2x1) {
|
||||
}
|
||||
|
||||
TEST(Reshape, 3x1x2To2x1x3) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 3, 1, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 1, 3);
|
||||
@ -75,12 +75,12 @@ TEST(Reshape, 3x1x2To2x1x3) {
|
||||
}
|
||||
|
||||
TEST(Reshape, 1x1x4To2x2x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 1, 4);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -96,12 +96,12 @@ TEST(Reshape, 1x1x4To2x2x1) {
|
||||
}
|
||||
|
||||
TEST(Reshape, BatchIsUnsupported) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(4, 1, 1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
|
@ -31,12 +31,12 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(SliceTest, Identity) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 2, 2);
|
||||
@ -54,12 +54,12 @@ TEST(SliceTest, Identity) {
|
||||
}
|
||||
|
||||
TEST(SliceTest, NegativeEnds) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 2, 2);
|
||||
@ -77,12 +77,12 @@ TEST(SliceTest, NegativeEnds) {
|
||||
}
|
||||
|
||||
TEST(SliceTest, NegativeEndsNonZeroStarts) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
@ -100,12 +100,12 @@ TEST(SliceTest, NegativeEndsNonZeroStarts) {
|
||||
}
|
||||
|
||||
TEST(SliceTest, StridesByHeight) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 4, 1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 1, 1);
|
||||
@ -123,12 +123,12 @@ TEST(SliceTest, StridesByHeight) {
|
||||
}
|
||||
|
||||
TEST(SliceTest, StridesByWidth) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 4, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 2, 1);
|
||||
@ -146,12 +146,12 @@ TEST(SliceTest, StridesByWidth) {
|
||||
}
|
||||
|
||||
TEST(SliceTest, StridesByChannels) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 1, 4);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
|
@ -32,12 +32,12 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(SoftmaxTest, WorksForChannelsAxis) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -53,12 +53,12 @@ TEST(SoftmaxTest, WorksForChannelsAxis) {
|
||||
}
|
||||
|
||||
TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
@ -75,12 +75,12 @@ TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
|
||||
}
|
||||
|
||||
TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 1);
|
||||
|
@ -37,8 +37,8 @@ namespace gpu {
|
||||
namespace gl {
|
||||
|
||||
SingleOpModel::SingleOpModel(Operation&& operation,
|
||||
const std::vector<TensorRefFloat32>& inputs,
|
||||
const std::vector<TensorRefFloat32>& outputs) {
|
||||
const std::vector<TensorRef<BHWC>>& inputs,
|
||||
const std::vector<TensorRef<BHWC>>& outputs) {
|
||||
auto node = graph_.NewNode();
|
||||
node->operation = std::move(operation);
|
||||
|
||||
|
@ -41,8 +41,8 @@ class SingleOpModel {
|
||||
public:
|
||||
SingleOpModel() = delete;
|
||||
SingleOpModel(Operation&& operation,
|
||||
const std::vector<TensorRefFloat32>& inputs,
|
||||
const std::vector<TensorRefFloat32>& outputs);
|
||||
const std::vector<TensorRef<BHWC>>& inputs,
|
||||
const std::vector<TensorRef<BHWC>>& outputs);
|
||||
|
||||
virtual ~SingleOpModel() = default;
|
||||
|
||||
|
@ -31,7 +31,7 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
@ -54,7 +54,7 @@ TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
|
||||
attr.adjacent = HW(1, 1);
|
||||
attr.stride = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 2, 2, 2);
|
||||
@ -69,7 +69,7 @@ TEST(TransposeConvTest, O2H2W1I1Stride1x1DAdjacent1x1) {
|
||||
}
|
||||
|
||||
TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 3, 3, 1);
|
||||
@ -92,7 +92,7 @@ TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.stride = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
@ -106,7 +106,7 @@ TEST(TransposeConvTest, O1H2W2I1Stride1x1Adjacent2x2) {
|
||||
}
|
||||
|
||||
TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
@ -129,7 +129,7 @@ TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.stride = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
@ -143,7 +143,7 @@ TEST(TransposeConvTest, O1H3W3I1Stride1x1Adjacent1x1) {
|
||||
}
|
||||
|
||||
TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 1, 2);
|
||||
@ -166,7 +166,7 @@ TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.stride = HW(1, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 2, 1, 2);
|
||||
@ -180,7 +180,7 @@ TEST(TransposeConvTest, O2H1W1I2Stride1x1Dilation1x1) {
|
||||
}
|
||||
|
||||
TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 3, 3, 1);
|
||||
@ -204,7 +204,7 @@ TEST(TransposeConvTest, O1H1W1I1Stride2x2Dilation1x1) {
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.stride = HW(2, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 3;
|
||||
output.shape = BHWC(1, 1, 1, 1);
|
||||
|
@ -31,12 +31,12 @@ namespace gl {
|
||||
namespace {
|
||||
|
||||
TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 1, 2);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 2, 2, 2);
|
||||
@ -56,12 +56,12 @@ TEST(UpsamplingBilinearTest, 1x1x2To2x2x2) {
|
||||
}
|
||||
|
||||
TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 4, 1);
|
||||
@ -80,12 +80,12 @@ TEST(UpsamplingBilinearTest, 1x2x1To1x4x1) {
|
||||
}
|
||||
|
||||
TEST(UpsamplingBilinearTest, 2x2x1To4x4x1) {
|
||||
TensorRefFloat32 input;
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 2, 2, 1);
|
||||
|
||||
TensorRefFloat32 output;
|
||||
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 4, 4, 1);
|
||||
|
@ -32,7 +32,7 @@ Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor,
|
||||
return CreateReadOnlyShaderStorageBuffer<float>(transposed, gl_buffer);
|
||||
}
|
||||
|
||||
Status CreatePHWC4BufferFromTensorRef(const TensorRefFloat32& tensor_ref,
|
||||
Status CreatePHWC4BufferFromTensorRef(const TensorRef<BHWC>& tensor_ref,
|
||||
GlBuffer* gl_buffer) {
|
||||
return CreateReadWriteShaderStorageBuffer<float>(
|
||||
GetElementsSizeForPHWC4(tensor_ref.shape), gl_buffer);
|
||||
|
@ -72,7 +72,7 @@ Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor,
|
||||
|
||||
// Creates read-write buffer for the given tensor shape, where data layout is
|
||||
// supposed to be PHWC4.
|
||||
Status CreatePHWC4BufferFromTensorRef(const TensorRefFloat32& tensor_ref,
|
||||
Status CreatePHWC4BufferFromTensorRef(const TensorRef<BHWC>& tensor_ref,
|
||||
GlBuffer* gl_buffer);
|
||||
|
||||
// Copies data from a buffer that holds data in PHWC4 layout to the given
|
||||
|
@ -148,7 +148,7 @@ class Delegate {
|
||||
|
||||
// TODO(impjdi): Remove code duplication.
|
||||
auto values = graph.values();
|
||||
auto find_value = [&](int tensor_index) -> Value<TensorRefFloat32>* {
|
||||
auto find_value = [&](int tensor_index) -> Value<TensorRef<BHWC>>* {
|
||||
for (auto value : values) {
|
||||
if (value->tensor.ref == tensor_index) return value;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user