Add a method for choosing a random TF dtype while ignoring a list of invalid ones.
PiperOrigin-RevId: 356317131 Change-Id: I4ef98c116d92d91e3fa8bfcb96b3f56a8b235469
This commit is contained in:
parent
f798728d37
commit
05fb0863d3
@ -126,12 +126,19 @@ class FuzzingHelper(object):
|
|||||||
else:
|
else:
|
||||||
return self.get_float_list(min_length, max_length)
|
return self.get_float_list(min_length, max_length)
|
||||||
|
|
||||||
def get_tf_dtype(self):
|
def get_tf_dtype(self, allowed_set=None):
|
||||||
"""Return a random tensorflow type.
|
"""Return a random tensorflow dtype.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_set: An allowlisted set of dtypes to choose from instead of all of
|
||||||
|
them.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A random type from the list containing all TensorFlow types.
|
A random type from the list containing all TensorFlow types.
|
||||||
"""
|
"""
|
||||||
|
if allowed_set:
|
||||||
|
index = self.get_int(0, len(allowed_set) - 1)
|
||||||
|
else:
|
||||||
index = self.get_int(0, len(_TF_DTYPES) - 1)
|
index = self.get_int(0, len(_TF_DTYPES) - 1)
|
||||||
return _TF_DTYPES[index]
|
return _TF_DTYPES[index]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user