diff --git a/tensorflow/security/fuzzing/python_fuzzing.py b/tensorflow/security/fuzzing/python_fuzzing.py index 1cabed46c41..42973a3f654 100644 --- a/tensorflow/security/fuzzing/python_fuzzing.py +++ b/tensorflow/security/fuzzing/python_fuzzing.py @@ -126,13 +126,20 @@ class FuzzingHelper(object): else: return self.get_float_list(min_length, max_length) - def get_tf_dtype(self): - """Return a random tensorflow type. + def get_tf_dtype(self, allowed_set=None): + """Return a random tensorflow dtype. + + Args: + allowed_set: An allowlisted set of dtypes to choose from instead of all of + them. Returns: A random type from the list containing all TensorFlow types. """ - index = self.get_int(0, len(_TF_DTYPES) - 1) + if allowed_set: + index = self.get_int(0, len(allowed_set) - 1) + else: + index = self.get_int(0, len(_TF_DTYPES) - 1) return _TF_DTYPES[index] def get_string(self, byte_count=_MAX_INT):