Internal change
PiperOrigin-RevId: 340148964 Change-Id: Ib0aa3d63285ee5a74b3ca002f4a220ce8919cf79
This commit is contained in:
parent
b5db1ca430
commit
1d5d379f19
@ -291,10 +291,13 @@ std::vector<string> UnarchiveAndFindTestNames(const string& zip_file,
|
|||||||
class OpsTest : public ::testing::TestWithParam<string> {};
|
class OpsTest : public ::testing::TestWithParam<string> {};
|
||||||
|
|
||||||
TEST_P(OpsTest, RunZipTests) {
|
TEST_P(OpsTest, RunZipTests) {
|
||||||
string test_path = GetParam();
|
string test_path_and_label = GetParam();
|
||||||
|
size_t end_pos = test_path_and_label.find(" ");
|
||||||
|
string test_path = test_path_and_label.substr(0, end_pos);
|
||||||
|
string label = test_path_and_label.substr(end_pos + 1);
|
||||||
string tflite_test_case = test_path + "_tests.txt";
|
string tflite_test_case = test_path + "_tests.txt";
|
||||||
string tflite_dir = test_path.substr(0, test_path.find_last_of("/"));
|
string tflite_dir = test_path.substr(0, test_path.find_last_of("/"));
|
||||||
string test_name = test_path.substr(test_path.find_last_of('/'));
|
string test_name = label.substr(label.find_last_of('/'));
|
||||||
|
|
||||||
std::ifstream tflite_stream(tflite_test_case);
|
std::ifstream tflite_stream(tflite_test_case);
|
||||||
ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
|
ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
|
||||||
@ -305,7 +308,7 @@ TEST_P(OpsTest, RunZipTests) {
|
|||||||
|
|
||||||
auto quantized_tests_error = GetQuantizeTestsError();
|
auto quantized_tests_error = GetQuantizeTestsError();
|
||||||
bool fully_quantize = false;
|
bool fully_quantize = false;
|
||||||
if (test_path.find("fully_quantize=True") != std::string::npos) {
|
if (label.find("fully_quantize=True") != std::string::npos) {
|
||||||
for (const auto& p : quantized_tests_error) {
|
for (const auto& p : quantized_tests_error) {
|
||||||
if (RE2::PartialMatch(test_name, p.first)) {
|
if (RE2::PartialMatch(test_name, p.first)) {
|
||||||
test_driver.SetQuantizationErrorMultiplier(p.second);
|
test_driver.SetQuantizationErrorMultiplier(p.second);
|
||||||
|
@ -40,6 +40,7 @@ def make_conv_activation_tests(activation_op):
|
|||||||
"constant_filter": [True, False],
|
"constant_filter": [True, False],
|
||||||
"channel_multiplier": [1, 2],
|
"channel_multiplier": [1, 2],
|
||||||
"fully_quantize": [False],
|
"fully_quantize": [False],
|
||||||
|
"quant_16x8": [False],
|
||||||
"dynamic_range_quantize": [False],
|
"dynamic_range_quantize": [False],
|
||||||
},
|
},
|
||||||
# TODO(b/134702301): The fully_quantize param is just ignored by the
|
# TODO(b/134702301): The fully_quantize param is just ignored by the
|
||||||
@ -47,14 +48,15 @@ def make_conv_activation_tests(activation_op):
|
|||||||
# these tests or handle it properly in the mlir_convert() function.
|
# these tests or handle it properly in the mlir_convert() function.
|
||||||
{
|
{
|
||||||
"input_shape": [[1, 3, 4, 3], [4, 6, 6, 1]],
|
"input_shape": [[1, 3, 4, 3], [4, 6, 6, 1]],
|
||||||
"filter_shape": [[1, 1], [2, 3], [3, 3]],
|
"filter_shape": [[1, 1], [2, 3]],
|
||||||
"strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
|
"strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
|
||||||
"dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]],
|
"dilations": [[1, 1, 1, 1], [1, 3, 2, 1]],
|
||||||
"padding": ["SAME", "VALID"],
|
"padding": ["SAME", "VALID"],
|
||||||
"data_format": ["NHWC"], # TODO(aselle): NCHW would be good
|
"data_format": ["NHWC"], # TODO(aselle): NCHW would be good
|
||||||
"constant_filter": [True],
|
"constant_filter": [True],
|
||||||
"channel_multiplier": [1, 2],
|
"channel_multiplier": [1, 2],
|
||||||
"fully_quantize": [True],
|
"fully_quantize": [True],
|
||||||
|
"quant_16x8": [False, True],
|
||||||
"dynamic_range_quantize": [False],
|
"dynamic_range_quantize": [False],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -67,6 +69,7 @@ def make_conv_activation_tests(activation_op):
|
|||||||
"constant_filter": [True],
|
"constant_filter": [True],
|
||||||
"channel_multiplier": [1, 2],
|
"channel_multiplier": [1, 2],
|
||||||
"fully_quantize": [False],
|
"fully_quantize": [False],
|
||||||
|
"quant_16x8": [False],
|
||||||
"dynamic_range_quantize": [True],
|
"dynamic_range_quantize": [True],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
@ -123,7 +126,7 @@ def make_conv_activation_tests(activation_op):
|
|||||||
test_parameters,
|
test_parameters,
|
||||||
build_graph,
|
build_graph,
|
||||||
build_inputs,
|
build_inputs,
|
||||||
expected_tf_failures=60)
|
expected_tf_failures=48)
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
@ -342,6 +342,7 @@ def make_zip_of_tests(options,
|
|||||||
if options.multi_gen_state:
|
if options.multi_gen_state:
|
||||||
label_base_path = options.multi_gen_state.label_base_path
|
label_base_path = options.multi_gen_state.label_base_path
|
||||||
|
|
||||||
|
i = 1
|
||||||
for parameters in test_parameters:
|
for parameters in test_parameters:
|
||||||
keys = parameters.keys()
|
keys = parameters.keys()
|
||||||
for curr in itertools.product(*parameters.values()):
|
for curr in itertools.product(*parameters.values()):
|
||||||
@ -349,6 +350,8 @@ def make_zip_of_tests(options,
|
|||||||
"%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
|
"%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
|
||||||
if label[0] == "/":
|
if label[0] == "/":
|
||||||
label = label[1:]
|
label = label[1:]
|
||||||
|
zip_path_label = label_base_path.replace(".zip", "_") + str(i)
|
||||||
|
i += 1
|
||||||
if label in processed_labels:
|
if label in processed_labels:
|
||||||
# Do not populate data for the same label more than once. It will cause
|
# Do not populate data for the same label more than once. It will cause
|
||||||
# errors when unzipping.
|
# errors when unzipping.
|
||||||
@ -397,13 +400,14 @@ def make_zip_of_tests(options,
|
|||||||
|
|
||||||
return input_values, output_values
|
return input_values, output_values
|
||||||
|
|
||||||
def build_example(label, param_dict_real):
|
def build_example(label, param_dict_real, zip_path_label):
|
||||||
"""Build the model with parameter values set in param_dict_real.
|
"""Build the model with parameter values set in param_dict_real.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
label: Label of the model (i.e. the filename in the zip).
|
label: Label of the model
|
||||||
param_dict_real: Parameter dictionary (arguments to the factories
|
param_dict_real: Parameter dictionary (arguments to the factories
|
||||||
make_graph and make_test_inputs)
|
make_graph and make_test_inputs)
|
||||||
|
zip_path_label: Filename in the zip
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tflite_model_binary, report) where tflite_model_binary is the
|
(tflite_model_binary, report) where tflite_model_binary is the
|
||||||
@ -466,7 +470,7 @@ def make_zip_of_tests(options,
|
|||||||
report["toco_log"] = toco_log
|
report["toco_log"] = toco_log
|
||||||
|
|
||||||
if options.save_graphdefs:
|
if options.save_graphdefs:
|
||||||
archive.writestr(label + ".pbtxt",
|
archive.writestr(zip_path_label + ".pbtxt",
|
||||||
text_format.MessageToString(graph_def),
|
text_format.MessageToString(graph_def),
|
||||||
zipfile.ZIP_DEFLATED)
|
zipfile.ZIP_DEFLATED)
|
||||||
|
|
||||||
@ -475,25 +479,25 @@ def make_zip_of_tests(options,
|
|||||||
# Set proper min max values according to input dtype.
|
# Set proper min max values according to input dtype.
|
||||||
baseline_inputs, baseline_outputs = generate_inputs_outputs(
|
baseline_inputs, baseline_outputs = generate_inputs_outputs(
|
||||||
tflite_model_binary, min_value=0, max_value=255)
|
tflite_model_binary, min_value=0, max_value=255)
|
||||||
archive.writestr(label + ".bin", tflite_model_binary,
|
archive.writestr(zip_path_label + ".bin", tflite_model_binary,
|
||||||
zipfile.ZIP_DEFLATED)
|
zipfile.ZIP_DEFLATED)
|
||||||
example = {"inputs": baseline_inputs, "outputs": baseline_outputs}
|
example = {"inputs": baseline_inputs, "outputs": baseline_outputs}
|
||||||
|
|
||||||
example_fp = StringIO()
|
example_fp = StringIO()
|
||||||
write_examples(example_fp, [example])
|
write_examples(example_fp, [example])
|
||||||
archive.writestr(label + ".inputs", example_fp.getvalue(),
|
archive.writestr(zip_path_label + ".inputs", example_fp.getvalue(),
|
||||||
zipfile.ZIP_DEFLATED)
|
zipfile.ZIP_DEFLATED)
|
||||||
|
|
||||||
example_fp2 = StringIO()
|
example_fp2 = StringIO()
|
||||||
write_test_cases(example_fp2, label + ".bin", [example])
|
write_test_cases(example_fp2, zip_path_label + ".bin", [example])
|
||||||
archive.writestr(label + "_tests.txt", example_fp2.getvalue(),
|
archive.writestr(zip_path_label + "_tests.txt",
|
||||||
zipfile.ZIP_DEFLATED)
|
example_fp2.getvalue(), zipfile.ZIP_DEFLATED)
|
||||||
|
|
||||||
zip_manifest.append(label + "\n")
|
zip_manifest.append(zip_path_label + " " + label + "\n")
|
||||||
|
|
||||||
return tflite_model_binary, report
|
return tflite_model_binary, report
|
||||||
|
|
||||||
_, report = build_example(label, param_dict)
|
_, report = build_example(label, param_dict, zip_path_label)
|
||||||
|
|
||||||
if report["toco"] == report_lib.FAILED:
|
if report["toco"] == report_lib.FAILED:
|
||||||
ignore_error = False
|
ignore_error = False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user