Modify InitializeTableFromDatasetOp to be async and add the table init op to the table initializers collection

PiperOrigin-RevId: 319190503
Change-Id: Ibd94b6f194e839fc11e37f1153c43533462bc265
A. Unique TensorFlower 2020-07-01 02:49:04 -07:00 committed by TensorFlower Gardener
parent fc49cbb2ad
commit 4fba4cbfdc
6 changed files with 71 additions and 42 deletions
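
With the init op registered in the TABLE_INITIALIZERS collection, graph-mode code can initialize dataset-backed tables through tf.compat.v1.tables_initializer() instead of running the initializer op by hand. The sketch below is illustrative only (not part of the diff); it mirrors the test added in this change, but uses the public API names tf.lookup.experimental.DatasetInitializer and tf.lookup.StaticHashTable, assuming a TF release that exports them.

# Illustrative usage sketch; assumes tf.lookup.experimental.DatasetInitializer
# and tf.lookup.StaticHashTable are available in the installed TF release.
import tensorflow as tf

with tf.Graph().as_default():
  # Build a dataset of (int64 key, string value) scalar pairs.
  keys = tf.data.Dataset.range(100)
  values = tf.data.Dataset.range(100).map(tf.strings.as_string)
  ds = tf.data.Dataset.zip((keys, values))

  init = tf.lookup.experimental.DatasetInitializer(ds)
  table = tf.lookup.StaticHashTable(init, default_value="")
  output = table.lookup(tf.constant([0, 2, 5], dtype=tf.int64))

  with tf.compat.v1.Session() as sess:
    # With this change, the dataset-backed init op is in the
    # TABLE_INITIALIZERS collection, so tables_initializer() covers it.
    sess.run(tf.compat.v1.tables_initializer())
    print(sess.run(output))  # [b'0' b'2' b'5']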


@@ -417,6 +417,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/framework:op_requires",
     ],
 )
 


@@ -164,24 +164,29 @@ REGISTER_KERNEL_BUILDER(
     Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
     InitializeTableFromTextFileOp);
 
-class InitializeTableFromDatasetOp : public OpKernel {
+class InitializeTableFromDatasetOp : public AsyncOpKernel {
  public:
   explicit InitializeTableFromDatasetOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx) {}
+      : AsyncOpKernel(ctx),
+        background_worker_(ctx->env(), "initialize_table_from_dataset") {}
 
-  void Compute(OpKernelContext* ctx) override {
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     lookup::InitializableLookupTable* table;
-    OP_REQUIRES_OK(ctx,
-                   GetInitializableLookupTable("table_handle", ctx, &table));
+    OP_REQUIRES_OK_ASYNC(
+        ctx, GetInitializableLookupTable("table_handle", ctx, &table), done);
     core::ScopedUnref unref_me(table);
     DatasetBase* dataset;
-    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset));
-    OP_REQUIRES_OK(ctx,
-                   lookup::InitializeTableFromDataset(ctx, dataset, table));
+    OP_REQUIRES_OK_ASYNC(
+        ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
+    background_worker_.Schedule([ctx, dataset, table, done]() {
+      lookup::InitializeTableFromDataset(ctx, dataset, table, done);
+    });
   }
 
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromDatasetOp);
+
+  data::BackgroundWorker background_worker_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("InitializeTableFromDataset").Device(DEVICE_CPU),


@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/function_handle_cache.h"
+#include "tensorflow/core/framework/op_requires.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -451,46 +452,54 @@ class DatasetIterator : public InitializableLookupTable::InitTableIterator {
   Status status_;
 };
 
-Status InitializeTableFromDataset(OpKernelContext* ctx,
-                                  data::DatasetBase* dataset,
-                                  InitializableLookupTable* table) {
+void InitializeTableFromDataset(OpKernelContext* ctx,
+                                data::DatasetBase* dataset,
+                                InitializableLookupTable* table,
+                                AsyncOpKernel::DoneCallback done) {
   // Assert that the dataset types match up to that expected in the table.
   const auto& dataset_types = dataset->output_dtypes();
-  if (dataset_types.size() != 2) {
-    return errors::InvalidArgument("Dataset should have two output types only");
-  }
-  if (dataset_types[0] != table->key_dtype()) {
-    return errors::InvalidArgument("Key dtype expected: ", table->key_dtype(),
-                                   " but obtained: ", dataset_types[0],
-                                   " from the dataset");
-  }
-  if (dataset_types[1] != table->value_dtype()) {
-    return errors::InvalidArgument(
-        "Value dtype expected: ", table->value_dtype(),
-        " but obtained: ", dataset_types[1], " from the dataset");
-  }
+  OP_REQUIRES_ASYNC(
+      ctx, dataset_types.size() == 2,
+      errors::InvalidArgument("Dataset should have two output types only"),
+      done);
+  OP_REQUIRES_ASYNC(
+      ctx, dataset_types[0] == table->key_dtype(),
+      errors::InvalidArgument("Key dtype expected: ", table->key_dtype(),
+                              " but obtained: ", dataset_types[0],
+                              " from the dataset"),
+      done);
+  OP_REQUIRES_ASYNC(
+      ctx, dataset_types[1] == table->value_dtype(),
+      errors::InvalidArgument("Value dtype expected: ", table->value_dtype(),
+                              " but obtained: ", dataset_types[1],
+                              " from the dataset"),
+      done);
   // Assert that the dataset output shapes are scalars.
   const auto& dataset_shapes = dataset->output_shapes();
-  if (dataset_shapes.size() != 2) {
-    return errors::InvalidArgument(
-        "Dataset should have two output shapes only");
-  }
-  if (!dataset_shapes[0].IsCompatibleWith(PartialTensorShape({}))) {
-    return errors::InvalidArgument("Expected scalar for key. Obtained: ",
-                                   dataset_shapes[0].DebugString());
-  }
-  if (!dataset_shapes[1].IsCompatibleWith(PartialTensorShape({}))) {
-    return errors::InvalidArgument("Expected scalar for key. Obtained: ",
-                                   dataset_shapes[1].DebugString());
-  }
+  OP_REQUIRES_ASYNC(
+      ctx, dataset_shapes.size() == 2,
+      errors::InvalidArgument("Dataset should have two output shapes only"),
+      done);
+  OP_REQUIRES_ASYNC(
+      ctx, dataset_shapes[0].IsCompatibleWith(PartialTensorShape({})),
+      errors::InvalidArgument("Expected scalar for key. Obtained: ",
+                              dataset_shapes[0].DebugString()),
+      done);
+  OP_REQUIRES_ASYNC(
+      ctx, dataset_shapes[1].IsCompatibleWith(PartialTensorShape({})),
+      errors::InvalidArgument("Expected scalar for key. Obtained: ",
+                              dataset_shapes[1].DebugString()),
+      done);
   DatasetIterator iter(dataset);
-  TF_RETURN_IF_ERROR(iter.Init(ctx));
+  OP_REQUIRES_OK_ASYNC(ctx, iter.Init(ctx), done);
   Status s = table->Initialize(iter);
   if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
     LOG(INFO) << "Table already initialized from dataset.";
-    return Status::OK();
+    done();
+    return;
   }
-  return s;
+  ctx->SetStatus(s);
+  done();
 }
 
 }  // namespace lookup


@@ -58,9 +58,10 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
 
 // Initializes `table` from `dataset` by iterating over it. Caller retains
 // ownership of `dataset`.
-Status InitializeTableFromDataset(OpKernelContext* ctx,
-                                  data::DatasetBase* dataset,
-                                  InitializableLookupTable* table);
+void InitializeTableFromDataset(OpKernelContext* ctx,
+                                data::DatasetBase* dataset,
+                                InitializableLookupTable* table,
+                                AsyncOpKernel::DoneCallback done);
 
 }  // namespace lookup
 }  // namespace tensorflow


@@ -565,6 +565,18 @@ class DatasetInitializerTest(BaseLookupTableTest):
     result = self.evaluate(output)
     self.assertAllEqual([1, 2, -1], result)
 
+  def test_compatibility(self):
+    with ops.Graph().as_default():
+      keys = dataset_ops.Dataset.range(100)
+      values = dataset_ops.Dataset.range(100).map(string_ops.as_string)
+      ds = dataset_ops.Dataset.zip((keys, values))
+      init = lookup_ops.DatasetInitializer(ds)
+      table = self.getHashTable()(init, default_value="")
+      output = table.lookup(constant_op.constant([0, 2, 5], dtypes.int64))
+      self.evaluate(lookup_ops.tables_initializer())
+      result = self.evaluate(output)
+      self.assertAllEqual(["0", "2", "5"], result)
+
 
 class InitializeTableFromFileOpTest(BaseLookupTableTest):
 


@@ -468,6 +468,7 @@ class DatasetInitializer(TableInitializerBase):
     _check_table_dtypes(table, self._key_dtype, self._value_dtype)
     init_op = gen_lookup_ops.initialize_table_from_dataset(
         table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
+    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
     return init_op
 
 