From 7bbc65be714be8d15378b58d0ea7cbc29e8c3769 Mon Sep 17 00:00:00 2001
From: Ruoxin Sang
Date: Wed, 6 Feb 2019 12:39:29 -0800
Subject: [PATCH] Support IteratorGetNextAsOptionalOp in TPU.

PiperOrigin-RevId: 232725342
---
 tensorflow/compiler/jit/xla_device_ops.h     |   2 +
 tensorflow/core/kernels/data/iterator_ops.cc | 114 ++++++++-----------
 tensorflow/core/kernels/data/iterator_ops.h  |  20 ++++
 3 files changed, 69 insertions(+), 67 deletions(-)

diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 927f983ba9e..f201f62a78c 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -241,6 +241,8 @@ class XlaAssignVariableOp : public OpKernel {
                           data::AnonymousIteratorHandleOp);                   \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),             \
                           data::IteratorGetNextOp);                           \
+  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE),   \
+                          data::IteratorGetNextAsOptionalOp);                 \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE),         \
                           data::IteratorGetNextSyncOp);                       \
   REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")                      \
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 808f834f62d..0d2dfd962bb 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -967,78 +967,58 @@ void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) {
   }
 }
 
-namespace {
+void IteratorGetNextAsOptionalOp::ComputeAsync(OpKernelContext* ctx,
+                                               DoneCallback done) {
+  IteratorResource* iterator;
+  OP_REQUIRES_OK_ASYNC(
+      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+  // The call to `iterator->GetNext()` may block and depend on an
+  // inter-op thread pool thread, so we issue the call from the
+  // owned thread pool.
+  background_worker_.Schedule(std::bind(
+      [this, ctx, iterator](DoneCallback done) {
+        std::vector<Tensor> components;
+        bool end_of_sequence = false;
 
-class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
- public:
-  explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
-      : AsyncOpKernel(ctx),
-        background_worker_(ctx->env(),
-                           "tf_data_iterator_get_next_as_optional") {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
-  }
+        Status s = iterator->GetNext(IteratorContext(ctx), &components,
+                                     &end_of_sequence);
+        // NOTE(mrry): We must unref the iterator before calling `done()`, to
+        // avoid destruction races.
+        iterator->Unref();
 
-  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
-    IteratorResource* iterator;
-    OP_REQUIRES_OK_ASYNC(
-        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
-    // The call to `iterator->GetNext()` may block and depend on an
-    // inter-op thread pool thread, so we issue the call from the
-    // owned thread pool.
-    background_worker_.Schedule(std::bind(
-        [this, ctx, iterator](DoneCallback done) {
-          std::vector<Tensor> components;
-          bool end_of_sequence = false;
-
-          Status s = iterator->GetNext(IteratorContext(ctx), &components,
-                                       &end_of_sequence);
-          // NOTE(mrry): We must unref the iterator before calling `done()`, to
-          // avoid destruction races.
-          iterator->Unref();
-
-          if (!s.ok()) {
-            ctx->SetStatus(s);
-          } else if (end_of_sequence) {
-            OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
-          } else {
-            for (int i = 0; i < components.size(); ++i) {
-              OP_REQUIRES_ASYNC(
-                  ctx, components[i].dtype() == output_types_[i],
-                  errors::InvalidArgument(
-                      "The given optional does not match the expected type for "
-                      "component ",
-                      i, ". Expected: ", DataTypeString(output_types_[i]),
-                      ". Actual: ", DataTypeString(components[i].dtype()), "."),
-                  done);
-              OP_REQUIRES_ASYNC(
-                  ctx,
-                  output_shapes_[i].IsCompatibleWith(components[i].shape()),
-                  errors::InvalidArgument(
-                      "The given optional does not match the expected shape "
-                      "for component ",
-                      i, ". Expected: ", output_shapes_[i].DebugString(),
-                      ". Actual: ", components[i].shape().DebugString(), "."),
-                  done);
-            }
-
-            OP_REQUIRES_OK_ASYNC(
-                ctx,
-                WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
+        if (!s.ok()) {
+          ctx->SetStatus(s);
+        } else if (end_of_sequence) {
+          OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
+        } else {
+          for (int i = 0; i < components.size(); ++i) {
+            OP_REQUIRES_ASYNC(
+                ctx, components[i].dtype() == output_types_[i],
+                errors::InvalidArgument(
+                    "The given optional does not match the expected type for "
+                    "component ",
+                    i, ". Expected: ", DataTypeString(output_types_[i]),
+                    ". Actual: ", DataTypeString(components[i].dtype()), "."),
+                done);
+            OP_REQUIRES_ASYNC(
+                ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
+                errors::InvalidArgument(
+                    "The given optional does not match the expected shape "
+                    "for component ",
+                    i, ". Expected: ", output_shapes_[i].DebugString(),
+                    ". Actual: ", components[i].shape().DebugString(), "."),
                 done);
           }
-          done();
-        },
-        std::move(done)));
-  }
 
- private:
-  BackgroundWorker background_worker_;
-  DataTypeVector output_types_;
-  std::vector<PartialTensorShape> output_shapes_;
-};
-
-}  // namespace
+          OP_REQUIRES_OK_ASYNC(
+              ctx,
+              WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
+              done);
+        }
+        done();
+      },
+      std::move(done)));
+}
 
 void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
   const Tensor& resource_handle_t = ctx->input(0);
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index cd722698590..7d769d365e9 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -19,6 +19,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/ops_util.h" namespace tensorflow { @@ -115,6 +117,24 @@ class IteratorGetNextOp : public AsyncOpKernel { BackgroundWorker background_worker_; }; +class IteratorGetNextAsOptionalOp : public AsyncOpKernel { + public: + explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + background_worker_(ctx->env(), + "tf_data_iterator_get_next_as_optional") { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + BackgroundWorker background_worker_; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + class IteratorGetNextSyncOp : public OpKernel { public: explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}