Making sure that recently introduced offset argument for the InitializeTableFromTextFile op is handled in a backwards / forwards compatible manner.

PiperOrigin-RevId: 358416520
Change-Id: I8b9637c5d9707097738de134e64b2cb72c73022a
This commit is contained in:
Jiri Simsa 2021-02-19 09:06:50 -08:00 committed by TensorFlower Gardener
parent e661958293
commit 2401461cee
2 changed files with 18 additions and 9 deletions

View File

@ -105,7 +105,9 @@ class InitializeTableFromTextFileOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
if (ctx->HasAttr("offset")) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
}
string delimiter;
OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter));
OP_REQUIRES(ctx, delimiter.size() == 1,
@ -155,7 +157,7 @@ class InitializeTableFromTextFileOp : public OpKernel {
char delimiter_;
int64 key_index_;
int64 value_index_;
int64 offset_;
int64 offset_ = 0;
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp);
};

View File

@ -24,6 +24,7 @@ import uuid
import six
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -45,7 +46,7 @@ from tensorflow.python.training.saver import BaseSaverBuilder
# pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import tracking as trackable
from tensorflow.python.util import compat
from tensorflow.python.util import compat as compat_util
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@ -749,10 +750,16 @@ class TextFileInitializer(TableInitializerBase):
with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)):
filename = ops.convert_to_tensor(
self._filename, dtypes.string, name="asset_filepath")
init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
table.resource_handle, filename, self._key_index, self._value_index,
-1 if self._vocab_size is None else self._vocab_size, self._delimiter,
self._offset)
if self._offset != 0 or compat.forward_compatible(2021, 3, 18):
init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
table.resource_handle, filename, self._key_index, self._value_index,
-1 if self._vocab_size is None else self._vocab_size,
self._delimiter, self._offset)
else:
init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
table.resource_handle, filename, self._key_index, self._value_index,
-1 if self._vocab_size is None else self._vocab_size,
self._delimiter)
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
# If the filename tensor is anything other than a string constant (e.g.,
# if it is a placeholder) then it does not make sense to track it as an
@ -915,8 +922,8 @@ class StrongHashSpec(HasherSpec):
if len(key) != 2:
raise ValueError("key must have size 2, got %s." % len(key))
if not isinstance(key[0], compat.integral_types) or not isinstance(
key[1], compat.integral_types):
if not isinstance(key[0], compat_util.integral_types) or not isinstance(
key[1], compat_util.integral_types):
raise TypeError("Invalid key %s. Must be unsigned integer values." % key)
return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)