Add a method for choosing a random TF dtype while ignoring a list of invalid ones.
PiperOrigin-RevId: 356317131 Change-Id: I4ef98c116d92d91e3fa8bfcb96b3f56a8b235469
This commit is contained in:
parent
f798728d37
commit
05fb0863d3
@ -126,13 +126,20 @@ class FuzzingHelper(object):
|
||||
else:
|
||||
return self.get_float_list(min_length, max_length)
|
||||
|
||||
def get_tf_dtype(self):
|
||||
"""Return a random tensorflow type.
|
||||
def get_tf_dtype(self, allowed_set=None):
|
||||
"""Return a random tensorflow dtype.
|
||||
|
||||
Args:
|
||||
allowed_set: An allowlisted set of dtypes to choose from instead of all of
|
||||
them.
|
||||
|
||||
Returns:
|
||||
A random type from the list containing all TensorFlow types.
|
||||
"""
|
||||
index = self.get_int(0, len(_TF_DTYPES) - 1)
|
||||
if allowed_set:
|
||||
index = self.get_int(0, len(allowed_set) - 1)
|
||||
else:
|
||||
index = self.get_int(0, len(_TF_DTYPES) - 1)
|
||||
return _TF_DTYPES[index]
|
||||
|
||||
def get_string(self, byte_count=_MAX_INT):
|
||||
|
Loading…
Reference in New Issue
Block a user