diff --git a/tensorflow/security/fuzzing/consume_leading_digits_fuzz.cc b/tensorflow/security/fuzzing/consume_leading_digits_fuzz.cc index d49bc1f2110..2c458bb988a 100644 --- a/tensorflow/security/fuzzing/consume_leading_digits_fuzz.cc +++ b/tensorflow/security/fuzzing/consume_leading_digits_fuzz.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <fuzzer/FuzzedDataProvider.h> + #include <cstdint> #include <cstdlib> @@ -23,16 +25,22 @@ limitations under the License. namespace { extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { - uint8_t *byte_data = const_cast<uint8_t *>(data); - char *char_data = reinterpret_cast<char *>(byte_data); + FuzzedDataProvider fuzzed_data(data, size); - tensorflow::StringPiece sp(char_data, size); - tensorflow::uint64 val; + while (fuzzed_data.remaining_bytes() > 0) { + std::string s = fuzzed_data.ConsumeRandomLengthString(25); + tensorflow::StringPiece sp(s); + tensorflow::uint64 val; - const bool leading_digits = - tensorflow::str_util::ConsumeLeadingDigits(&sp, &val); - if (leading_digits) { - assert(val >= 0); + const bool leading_digits = + tensorflow::str_util::ConsumeLeadingDigits(&sp, &val); + const char lead_char_consume_digits = *(sp.data()); + if (leading_digits) { + if (lead_char_consume_digits >= '0') { + assert(lead_char_consume_digits > '9'); + } + assert(val >= 0); + } } return 0;