diff --git a/tensorflow/security/fuzzing/python_fuzzing.py b/tensorflow/security/fuzzing/python_fuzzing.py
index 1cabed46c41..42973a3f654 100644
--- a/tensorflow/security/fuzzing/python_fuzzing.py
+++ b/tensorflow/security/fuzzing/python_fuzzing.py
@@ -126,13 +126,20 @@ class FuzzingHelper(object):
     else:
       return self.get_float_list(min_length, max_length)
 
-  def get_tf_dtype(self):
-    """Return a random tensorflow type.
+  def get_tf_dtype(self, allowed_set=None):
+    """Return a random tensorflow dtype.
+
+    Args:
+      allowed_set: An allowlisted set of dtypes to choose from instead of all of
+      them.
 
     Returns:
       A random type from the list containing all TensorFlow types.
     """
-    index = self.get_int(0, len(_TF_DTYPES) - 1)
+    if allowed_set:
+      index = self.get_int(0, len(allowed_set) - 1)
+    else:
+      index = self.get_int(0, len(_TF_DTYPES) - 1)
     return _TF_DTYPES[index]
 
   def get_string(self, byte_count=_MAX_INT):