iOS Metal delegate: squared diff operation tests added.
PiperOrigin-RevId: 272055201
This commit is contained in:
parent
506336dd8b
commit
00979a1a95
@ -163,6 +163,7 @@ OperationType OperationTypeFromString(const std::string& name) {
|
||||
{"softmax", OperationType::SOFTMAX},
|
||||
{"sqrt", OperationType::SQRT},
|
||||
{"square", OperationType::SQUARE},
|
||||
{"squared_diff", OperationType::SQUARED_DIFF},
|
||||
{"subtract", OperationType::SUB},
|
||||
{"tanh", OperationType::TANH},
|
||||
{"upsample_2d", OperationType::UPSAMPLE_2D},
|
||||
|
@ -198,6 +198,21 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
|
||||
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
|
||||
}
|
||||
|
||||
- (void)testSquaredDiff {
|
||||
OperationType op_type = OperationType::SQUARED_DIFF;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
SingleOpModel model(
|
||||
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
|
||||
/*outputs=*/{GetTensorRef(2, shape)});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {0.0, 2.0, 2.0, 4.0}));
|
||||
XCTAssertTrue(model.PopulateTensor(1, {1.0, 1.0, 5.0, 4.0}));
|
||||
auto status = model.Invoke();
|
||||
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
|
||||
status = CompareVectors({1.0, 1.0, 9.0, 0.0}, model.GetOutput(0), 1e-6f);
|
||||
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
|
||||
}
|
||||
|
||||
- (void)testSub {
|
||||
OperationType op_type = OperationType::SUB;
|
||||
const BHWC shape(1, 2, 2, 1);
|
||||
|
Loading…
x
Reference in New Issue
Block a user