iOS Metal delegate: squared diff operation tests added.

PiperOrigin-RevId: 272055201
This commit is contained in:
A. Unique TensorFlower 2019-09-30 13:39:34 -07:00 committed by TensorFlower Gardener
parent 506336dd8b
commit 00979a1a95
2 changed files with 16 additions and 0 deletions

View File

@ -163,6 +163,7 @@ OperationType OperationTypeFromString(const std::string& name) {
{"softmax", OperationType::SOFTMAX},
{"sqrt", OperationType::SQRT},
{"square", OperationType::SQUARE},
{"squared_diff", OperationType::SQUARED_DIFF},
{"subtract", OperationType::SUB},
{"tanh", OperationType::TANH},
{"upsample_2d", OperationType::UPSAMPLE_2D},

View File

@ -198,6 +198,21 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
}
- (void)testSquaredDiff {
OperationType op_type = OperationType::SQUARED_DIFF;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
XCTAssertTrue(model.PopulateTensor(0, {0.0, 2.0, 2.0, 4.0}));
XCTAssertTrue(model.PopulateTensor(1, {1.0, 1.0, 5.0, 4.0}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
status = CompareVectors({1.0, 1.0, 9.0, 0.0}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", status.ToString().c_str());
}
- (void)testSub {
OperationType op_type = OperationType::SUB;
const BHWC shape(1, 2, 2, 1);