Modify InitializeTableFromDatasetOp to be async and add the table init op to the table initializers collection
PiperOrigin-RevId: 319190503 Change-Id: Ibd94b6f194e839fc11e37f1153c43533462bc265
This commit is contained in:
parent
fc49cbb2ad
commit
4fba4cbfdc
tensorflow
core/kernels
python
@ -417,6 +417,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/framework:op_requires",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -164,24 +164,29 @@ REGISTER_KERNEL_BUILDER(
|
||||
Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
|
||||
InitializeTableFromTextFileOp);
|
||||
|
||||
class InitializeTableFromDatasetOp : public OpKernel {
|
||||
class InitializeTableFromDatasetOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit InitializeTableFromDatasetOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {}
|
||||
: AsyncOpKernel(ctx),
|
||||
background_worker_(ctx->env(), "initialize_table_from_dataset") {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
lookup::InitializableLookupTable* table;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
GetInitializableLookupTable("table_handle", ctx, &table));
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, GetInitializableLookupTable("table_handle", ctx, &table), done);
|
||||
core::ScopedUnref unref_me(table);
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
lookup::InitializeTableFromDataset(ctx, dataset, table));
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
|
||||
background_worker_.Schedule([ctx, dataset, table, done]() {
|
||||
lookup::InitializeTableFromDataset(ctx, dataset, table, done);
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromDatasetOp);
|
||||
|
||||
data::BackgroundWorker background_worker_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("InitializeTableFromDataset").Device(DEVICE_CPU),
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/function_handle_cache.h"
|
||||
#include "tensorflow/core/framework/op_requires.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -451,46 +452,54 @@ class DatasetIterator : public InitializableLookupTable::InitTableIterator {
|
||||
Status status_;
|
||||
};
|
||||
|
||||
Status InitializeTableFromDataset(OpKernelContext* ctx,
|
||||
data::DatasetBase* dataset,
|
||||
InitializableLookupTable* table) {
|
||||
void InitializeTableFromDataset(OpKernelContext* ctx,
|
||||
data::DatasetBase* dataset,
|
||||
InitializableLookupTable* table,
|
||||
AsyncOpKernel::DoneCallback done) {
|
||||
// Assert that the dataset types match up to that expected in the table.
|
||||
const auto& dataset_types = dataset->output_dtypes();
|
||||
if (dataset_types.size() != 2) {
|
||||
return errors::InvalidArgument("Dataset should have two output types only");
|
||||
}
|
||||
if (dataset_types[0] != table->key_dtype()) {
|
||||
return errors::InvalidArgument("Key dtype expected: ", table->key_dtype(),
|
||||
" but obtained: ", dataset_types[0],
|
||||
" from the dataset");
|
||||
}
|
||||
if (dataset_types[1] != table->value_dtype()) {
|
||||
return errors::InvalidArgument(
|
||||
"Value dtype expected: ", table->value_dtype(),
|
||||
" but obtained: ", dataset_types[1], " from the dataset");
|
||||
}
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, dataset_types.size() == 2,
|
||||
errors::InvalidArgument("Dataset should have two output types only"),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, dataset_types[0] == table->key_dtype(),
|
||||
errors::InvalidArgument("Key dtype expected: ", table->key_dtype(),
|
||||
" but obtained: ", dataset_types[0],
|
||||
" from the dataset"),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, dataset_types[1] == table->value_dtype(),
|
||||
errors::InvalidArgument("Value dtype expected: ", table->value_dtype(),
|
||||
" but obtained: ", dataset_types[1],
|
||||
" from the dataset"),
|
||||
done);
|
||||
// Assert that the dataset output shapes are scalars.
|
||||
const auto& dataset_shapes = dataset->output_shapes();
|
||||
if (dataset_shapes.size() != 2) {
|
||||
return errors::InvalidArgument(
|
||||
"Dataset should have two output shapes only");
|
||||
}
|
||||
if (!dataset_shapes[0].IsCompatibleWith(PartialTensorShape({}))) {
|
||||
return errors::InvalidArgument("Expected scalar for key. Obtained: ",
|
||||
dataset_shapes[0].DebugString());
|
||||
}
|
||||
if (!dataset_shapes[1].IsCompatibleWith(PartialTensorShape({}))) {
|
||||
return errors::InvalidArgument("Expected scalar for key. Obtained: ",
|
||||
dataset_shapes[1].DebugString());
|
||||
}
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, dataset_shapes.size() == 2,
|
||||
errors::InvalidArgument("Dataset should have two output shapes only"),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, dataset_shapes[0].IsCompatibleWith(PartialTensorShape({})),
|
||||
errors::InvalidArgument("Expected scalar for key. Obtained: ",
|
||||
dataset_shapes[0].DebugString()),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, dataset_shapes[1].IsCompatibleWith(PartialTensorShape({})),
|
||||
errors::InvalidArgument("Expected scalar for key. Obtained: ",
|
||||
dataset_shapes[1].DebugString()),
|
||||
done);
|
||||
DatasetIterator iter(dataset);
|
||||
TF_RETURN_IF_ERROR(iter.Init(ctx));
|
||||
OP_REQUIRES_OK_ASYNC(ctx, iter.Init(ctx), done);
|
||||
Status s = table->Initialize(iter);
|
||||
if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
|
||||
LOG(INFO) << "Table already initialized from dataset.";
|
||||
return Status::OK();
|
||||
done();
|
||||
return;
|
||||
}
|
||||
return s;
|
||||
ctx->SetStatus(s);
|
||||
done();
|
||||
}
|
||||
|
||||
} // namespace lookup
|
||||
|
@ -58,9 +58,10 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
|
||||
|
||||
// Initializes `table` from `dataset` by iterating over it. Caller retains
|
||||
// ownership of `dataset`.
|
||||
Status InitializeTableFromDataset(OpKernelContext* ctx,
|
||||
data::DatasetBase* dataset,
|
||||
InitializableLookupTable* table);
|
||||
void InitializeTableFromDataset(OpKernelContext* ctx,
|
||||
data::DatasetBase* dataset,
|
||||
InitializableLookupTable* table,
|
||||
AsyncOpKernel::DoneCallback done);
|
||||
|
||||
} // namespace lookup
|
||||
} // namespace tensorflow
|
||||
|
@ -565,6 +565,18 @@ class DatasetInitializerTest(BaseLookupTableTest):
|
||||
result = self.evaluate(output)
|
||||
self.assertAllEqual([1, 2, -1], result)
|
||||
|
||||
def test_compatibility(self):
  """Graph-mode check: a dataset-backed table is set up by tables_initializer().

  Builds a hash table from a zipped (int64 key, string value) dataset, runs
  `lookup_ops.tables_initializer()` and then verifies a few lookups.
  """
  with ops.Graph().as_default():
    key_ds = dataset_ops.Dataset.range(100)
    value_ds = dataset_ops.Dataset.range(100).map(string_ops.as_string)
    zipped = dataset_ops.Dataset.zip((key_ds, value_ds))
    initializer = lookup_ops.DatasetInitializer(zipped)
    table = self.getHashTable()(initializer, default_value="")
    queries = constant_op.constant([0, 2, 5], dtypes.int64)
    output = table.lookup(queries)
    # Initialization goes through the collection-based initializer rather
    # than through the table's own initializer op.
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual(["0", "2", "5"], self.evaluate(output))
|
||||
|
||||
|
||||
class InitializeTableFromFileOpTest(BaseLookupTableTest):
|
||||
|
||||
|
@ -468,6 +468,7 @@ class DatasetInitializer(TableInitializerBase):
|
||||
_check_table_dtypes(table, self._key_dtype, self._value_dtype)
|
||||
init_op = gen_lookup_ops.initialize_table_from_dataset(
|
||||
table.resource_handle, self.dataset._variant_tensor) # pylint: disable=protected-access
|
||||
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
|
||||
return init_op
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user