Some tests fixes.

PiperOrigin-RevId: 310170770
Change-Id: Ie957e6f9ac3eee6be3fa4db7a8bbdaf57e15f508
This commit is contained in:
Raman Sarokin 2020-05-06 09:49:37 -07:00 committed by TensorFlower Gardener
parent 1a18b711ae
commit c43c69e388
2 changed files with 9 additions and 6 deletions

View File

@ -57,12 +57,12 @@ kernel void FunctionName(device TYPE* const src_buffer[[buffer(0)]],
NSDictionary* macrosFloat4 = @{@"TYPE" : @"float4"};
status = CreateComputeProgram(device, code, functionName, macrosFloat4, &program);
XCTAssertTrue(status.ok(), @"%s", std::string(status.messasge()).c_str());
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
XCTAssertNotNil(program);
NSDictionary* macrosHalf4 = @{@"TYPE" : @"half4"};
status = CreateComputeProgram(device, code, functionName, macrosHalf4, &program);
XCTAssertTrue(status.ok(), @"%s", std::string(status.messasge()).c_str());
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
XCTAssertNotNil(program);
// This compilation is intended to be incorrect

View File

@ -159,7 +159,8 @@ static std::vector<ComputeTaskDescriptorPtr> Add2Linkable(int id, ValueId input_
std::vector<ComputeTaskDescriptorPtr> descriptors;
descriptors.push_back(ComputeTaskDescriptorPtr(new ComputeTaskDescriptor({
id,
true, // Is linkable?
true, // linkable
true, // associative_op
R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, device FLT4* const buffer2) {
return value + buffer2[linear_index];
}
@ -250,12 +251,14 @@ static std::vector<ComputeTaskDescriptorPtr> Add2Linkable(int id, ValueId input_
- (void)testAddOperationFused {
auto graph = Add(1, 1, 3);
auto graph2 = Add2Linkable(2, 2, 3, 4);
auto graph2 = Add(1, 2, 4);
auto graph3 = Add2Linkable(2, 4, 3, 5);
graph.insert(graph.end(), graph2.begin(), graph2.end());
graph.insert(graph.end(), graph3.begin(), graph3.end());
std::vector<ComputeTaskDescriptorPtr> model;
auto status = ValidateOptimizeModel({1, 2}, {4}, graph, &model);
auto status = ValidateOptimizeModel({1, 2}, {5}, graph, &model);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
XCTAssertTrue(model.size() == 1, @"Not fused, more than one task descriptor.");
XCTAssertTrue(model.size() <= 2, @"Not fused, more than two task descriptors.");
}
- (void)testBinaryOperationSuccess {