From 2401461cee208f6a734e09a69ed9a3852497a502 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 19 Feb 2021 09:06:50 -0800 Subject: [PATCH] Making sure that recently introduced `offset` argument for the `InitializeTableFromTextFile` op is handled in a backwards / forwards compatible manner. PiperOrigin-RevId: 358416520 Change-Id: I8b9637c5d9707097738de134e64b2cb72c73022a --- .../core/kernels/lookup_table_init_op.cc | 6 ++++-- tensorflow/python/ops/lookup_ops.py | 21 ++++++++++++------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc index d21ac547db2..9b9e3cd1207 100644 --- a/tensorflow/core/kernels/lookup_table_init_op.cc +++ b/tensorflow/core/kernels/lookup_table_init_op.cc @@ -105,7 +105,9 @@ class InitializeTableFromTextFileOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_)); + if (ctx->HasAttr("offset")) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_)); + } string delimiter; OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter)); OP_REQUIRES(ctx, delimiter.size() == 1, @@ -155,7 +157,7 @@ class InitializeTableFromTextFileOp : public OpKernel { char delimiter_; int64 key_index_; int64 value_index_; - int64 offset_; + int64 offset_ = 0; TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp); }; diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index e541a00bd96..1e7cabd87fe 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -24,6 +24,7 @@ import uuid import six +from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -45,7 +46,7 @@ from tensorflow.python.training.saver import BaseSaverBuilder # pylint: enable=wildcard-import from tensorflow.python.training.tracking import base as trackable_base from tensorflow.python.training.tracking import tracking as trackable -from tensorflow.python.util import compat +from tensorflow.python.util import compat as compat_util from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -749,10 +750,16 @@ class TextFileInitializer(TableInitializerBase): with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)): filename = ops.convert_to_tensor( self._filename, dtypes.string, name="asset_filepath") - init_op = gen_lookup_ops.initialize_table_from_text_file_v2( - table.resource_handle, filename, self._key_index, self._value_index, - -1 if self._vocab_size is None else self._vocab_size, self._delimiter, - self._offset) + if self._offset != 0 or compat.forward_compatible(2021, 3, 18): + init_op = gen_lookup_ops.initialize_table_from_text_file_v2( + table.resource_handle, filename, self._key_index, self._value_index, + -1 if self._vocab_size is None else self._vocab_size, + self._delimiter, self._offset) + else: + init_op = gen_lookup_ops.initialize_table_from_text_file_v2( + table.resource_handle, filename, self._key_index, self._value_index, + -1 if self._vocab_size is None else self._vocab_size, + self._delimiter) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) # If the filename tensor is anything other than a string constant (e.g., # if it is a placeholder) then it does not make sense to track it as an @@ -915,8 +922,8 @@ class StrongHashSpec(HasherSpec): if len(key) != 2: raise ValueError("key must have size 2, got %s." % len(key)) - if not isinstance(key[0], compat.integral_types) or not isinstance( - key[1], compat.integral_types): + if not isinstance(key[0], compat_util.integral_types) or not isinstance( + key[1], compat_util.integral_types): raise TypeError("Invalid key %s. Must be unsigned integer values." % key) return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)