Allow file-based initializers with integer values.
PiperOrigin-RevId: 358292431 Change-Id: Id27fc2dc0be23ef328503c73394bcbe1f0de59bc
This commit is contained in:
parent
c29e9f25e7
commit
be625d2fe8
@ -375,9 +375,11 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
|
||||
"Value index for line number requires table value dtype of int64, got ",
|
||||
DataTypeString(table->value_dtype()));
|
||||
}
|
||||
if (value_index == kWholeLine && value_dtype != DT_STRING) {
|
||||
if (value_index == kWholeLine && !DataTypeIsInteger(value_dtype) &&
|
||||
value_dtype != DT_STRING) {
|
||||
return errors::InvalidArgument(
|
||||
"Value index for whole line requires table value dtype of string, got ",
|
||||
"Value index for whole line requires table value dtype of integer or "
|
||||
"string, got ",
|
||||
DataTypeString(table->value_dtype()));
|
||||
}
|
||||
|
||||
|
@ -659,7 +659,7 @@ class TextFileInitializer(TableInitializerBase):
|
||||
- TextFileIndex.LINE_NUMBER means use the line number starting from zero,
|
||||
expects data type int64.
|
||||
- TextFileIndex.WHOLE_LINE means use the whole line content, expects data
|
||||
type string.
|
||||
type string or int64.
|
||||
- A value >=0 means use the index (starting at zero) of the split line based
|
||||
on `delimiter`.
|
||||
|
||||
@ -712,9 +712,11 @@ class TextFileInitializer(TableInitializerBase):
|
||||
if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
|
||||
raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
|
||||
(dtypes.int64, value_dtype))
|
||||
if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
|
||||
raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
|
||||
(dtypes.string, value_dtype))
|
||||
if ((value_index == TextFileIndex.WHOLE_LINE) and
|
||||
(not value_dtype.is_integer) and (value_dtype != dtypes.string)):
|
||||
raise ValueError(
|
||||
"Signature mismatch. Values must be integer or string, got %s." %
|
||||
(value_dtype))
|
||||
|
||||
if (vocab_size is not None) and (vocab_size <= 0):
|
||||
raise ValueError("Invalid vocab_size %s." % vocab_size)
|
||||
@ -795,7 +797,7 @@ class TextFileStringTableInitializer(TextFileInitializer):
|
||||
- TextFileIndex.LINE_NUMBER means use the line number starting from zero,
|
||||
expects data type int64.
|
||||
- TextFileIndex.WHOLE_LINE means use the whole line content, expects data
|
||||
type string.
|
||||
type string or int64.
|
||||
- A value >=0 means use the index (starting at zero) of the split line based
|
||||
on `delimiter`.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user