Add a method for choosing a random TF dtype while ignoring a list of invalid ones.

PiperOrigin-RevId: 356317131
Change-Id: I4ef98c116d92d91e3fa8bfcb96b3f56a8b235469
This commit is contained in:
Amit Patankar 2021-02-08 11:49:48 -08:00 committed by TensorFlower Gardener
parent f798728d37
commit 05fb0863d3

View File

@ -126,13 +126,20 @@ class FuzzingHelper(object):
else:
return self.get_float_list(min_length, max_length)
def get_tf_dtype(self):
"""Return a random tensorflow type.
def get_tf_dtype(self, allowed_set=None):
"""Return a random tensorflow dtype.
Args:
allowed_set: An allowlisted set of dtypes to choose from instead of all of
them.
Returns:
A random type from the list containing all TensorFlow types.
"""
index = self.get_int(0, len(_TF_DTYPES) - 1)
if allowed_set:
index = self.get_int(0, len(allowed_set) - 1)
else:
index = self.get_int(0, len(_TF_DTYPES) - 1)
return _TF_DTYPES[index]
def get_string(self, byte_count=_MAX_INT):