Make select quantizable
PiperOrigin-RevId: 355053742 Change-Id: Iba3448010822e1693e8acb1c470d2db9997e437e
This commit is contained in:
parent
f5970da0fc
commit
d981631696
@ -32,11 +32,19 @@ def make_where_tests(options):
|
|||||||
"input_dtype": [tf.float32, tf.int32],
|
"input_dtype": [tf.float32, tf.int32],
|
||||||
"input_shape_set": [([1, 2, 3, 4], [1, 2, 3, 4]),],
|
"input_shape_set": [([1, 2, 3, 4], [1, 2, 3, 4]),],
|
||||||
"use_where_v2": [False, True],
|
"use_where_v2": [False, True],
|
||||||
|
"fully_quantize": [False],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"input_dtype": [tf.float32, tf.int32],
|
"input_dtype": [tf.float32, tf.int32],
|
||||||
"input_shape_set": [([], []),],
|
"input_shape_set": [([], []),],
|
||||||
"use_where_v2": [],
|
"use_where_v2": [],
|
||||||
|
"fully_quantize": [False],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"input_dtype": [tf.float32],
|
||||||
|
"input_shape_set": [([1, 2, 3, 4], [1, 2, 3, 4]), ([], []),],
|
||||||
|
"use_where_v2": [False, True],
|
||||||
|
"fully_quantize": [True],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -47,6 +55,13 @@ def make_where_tests(options):
|
|||||||
"input_dtype": [tf.float32, tf.int32],
|
"input_dtype": [tf.float32, tf.int32],
|
||||||
"input_shape_set": [([8, 7, 6, 5, 4, 3, 2, 1], [4, 3, 2, 1]),],
|
"input_shape_set": [([8, 7, 6, 5, 4, 3, 2, 1], [4, 3, 2, 1]),],
|
||||||
"use_where_v2": [True],
|
"use_where_v2": [True],
|
||||||
|
"fully_quantize": [False],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"input_dtype": [tf.float32],
|
||||||
|
"input_shape_set": [([8, 7, 6, 5, 4, 3, 2, 1], [4, 3, 2, 1]),],
|
||||||
|
"use_where_v2": [True],
|
||||||
|
"fully_quantize": [True],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -67,9 +82,11 @@ def make_where_tests(options):
|
|||||||
|
|
||||||
def build_inputs(parameters, sess, inputs, outputs):
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
input_value1 = create_tensor_data(parameters["input_dtype"],
|
input_value1 = create_tensor_data(parameters["input_dtype"],
|
||||||
parameters["input_shape_set"][0])
|
parameters["input_shape_set"][0],
|
||||||
|
min_value=-1, max_value=1)
|
||||||
input_value2 = create_tensor_data(parameters["input_dtype"],
|
input_value2 = create_tensor_data(parameters["input_dtype"],
|
||||||
parameters["input_shape_set"][1])
|
parameters["input_shape_set"][1],
|
||||||
|
min_value=-1, max_value=1)
|
||||||
return [input_value1, input_value2], sess.run(
|
return [input_value1, input_value2], sess.run(
|
||||||
outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
|
outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
|
||||||
|
|
||||||
|
@ -907,6 +907,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
|
|||||||
property.restrict_same_input_output_scale = true;
|
property.restrict_same_input_output_scale = true;
|
||||||
property.version = 3;
|
property.version = 3;
|
||||||
break;
|
break;
|
||||||
|
case BuiltinOperator_SELECT:
|
||||||
|
property.inputs = {{1, {}}, {2, {}}};
|
||||||
|
property.outputs = {{0, {}}};
|
||||||
|
property.restrict_same_input_output_scale = true;
|
||||||
|
property.version = 1;
|
||||||
|
break;
|
||||||
case BuiltinOperator_SHAPE:
|
case BuiltinOperator_SHAPE:
|
||||||
property.inputs = {{0, {}}};
|
property.inputs = {{0, {}}};
|
||||||
// Shape has no quantizable output.
|
// Shape has no quantizable output.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user