Support IteratorGetNextAsOptionalOp in TPU.

PiperOrigin-RevId: 232725342
Ruoxin Sang 2019-02-06 12:39:29 -08:00 committed by TensorFlower Gardener
parent 3f8dcd3e28
commit 7bbc65be71
3 changed files with 69 additions and 67 deletions


@@ -241,6 +241,8 @@ class XlaAssignVariableOp : public OpKernel {
                           data::AnonymousIteratorHandleOp);                 \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),           \
                           data::IteratorGetNextOp);                         \
+  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
+                          data::IteratorGetNextAsOptionalOp);               \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE),       \
                           data::IteratorGetNextSyncOp);                     \
   REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")                    \


@@ -967,78 +967,58 @@ void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) {
   }
 }
 
-namespace {
-
-class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
- public:
-  explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
-      : AsyncOpKernel(ctx),
-        background_worker_(ctx->env(),
-                           "tf_data_iterator_get_next_as_optional") {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
-  }
-
-  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
-    IteratorResource* iterator;
-    OP_REQUIRES_OK_ASYNC(
-        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
-    // The call to `iterator->GetNext()` may block and depend on an
-    // inter-op thread pool thread, so we issue the call from the
-    // owned thread pool.
-    background_worker_.Schedule(std::bind(
-        [this, ctx, iterator](DoneCallback done) {
-          std::vector<Tensor> components;
-          bool end_of_sequence = false;
-
-          Status s = iterator->GetNext(IteratorContext(ctx), &components,
-                                       &end_of_sequence);
-          // NOTE(mrry): We must unref the iterator before calling `done()`, to
-          // avoid destruction races.
-          iterator->Unref();
-
-          if (!s.ok()) {
-            ctx->SetStatus(s);
-          } else if (end_of_sequence) {
-            OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
-          } else {
-            for (int i = 0; i < components.size(); ++i) {
-              OP_REQUIRES_ASYNC(
-                  ctx, components[i].dtype() == output_types_[i],
-                  errors::InvalidArgument(
-                      "The given optional does not match the expected type for "
-                      "component ",
-                      i, ". Expected: ", DataTypeString(output_types_[i]),
-                      ". Actual: ", DataTypeString(components[i].dtype()), "."),
-                  done);
-              OP_REQUIRES_ASYNC(
-                  ctx,
-                  output_shapes_[i].IsCompatibleWith(components[i].shape()),
-                  errors::InvalidArgument(
-                      "The given optional does not match the expected shape "
-                      "for component ",
-                      i, ". Expected: ", output_shapes_[i].DebugString(),
-                      ". Actual: ", components[i].shape().DebugString(), "."),
-                  done);
-            }
-            OP_REQUIRES_OK_ASYNC(
-                ctx,
-                WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
-                done);
-          }
-          done();
-        },
-        std::move(done)));
-  }
-
- private:
-  BackgroundWorker background_worker_;
-  DataTypeVector output_types_;
-  std::vector<PartialTensorShape> output_shapes_;
-};
-
-}  // namespace
+void IteratorGetNextAsOptionalOp::ComputeAsync(OpKernelContext* ctx,
+                                               DoneCallback done) {
+  IteratorResource* iterator;
+  OP_REQUIRES_OK_ASYNC(
+      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+  // The call to `iterator->GetNext()` may block and depend on an
+  // inter-op thread pool thread, so we issue the call from the
+  // owned thread pool.
+  background_worker_.Schedule(std::bind(
+      [this, ctx, iterator](DoneCallback done) {
+        std::vector<Tensor> components;
+        bool end_of_sequence = false;
+
+        Status s = iterator->GetNext(IteratorContext(ctx), &components,
+                                     &end_of_sequence);
+        // NOTE(mrry): We must unref the iterator before calling `done()`, to
+        // avoid destruction races.
+        iterator->Unref();
+
+        if (!s.ok()) {
+          ctx->SetStatus(s);
+        } else if (end_of_sequence) {
+          OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
+        } else {
+          for (int i = 0; i < components.size(); ++i) {
+            OP_REQUIRES_ASYNC(
+                ctx, components[i].dtype() == output_types_[i],
+                errors::InvalidArgument(
+                    "The given optional does not match the expected type for "
+                    "component ",
+                    i, ". Expected: ", DataTypeString(output_types_[i]),
+                    ". Actual: ", DataTypeString(components[i].dtype()), "."),
+                done);
+            OP_REQUIRES_ASYNC(
+                ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
+                errors::InvalidArgument(
+                    "The given optional does not match the expected shape "
+                    "for component ",
+                    i, ". Expected: ", output_shapes_[i].DebugString(),
+                    ". Actual: ", components[i].shape().DebugString(), "."),
+                done);
+          }
+          OP_REQUIRES_OK_ASYNC(
+              ctx,
+              WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
+              done);
+        }
+        done();
+      },
+      std::move(done)));
+}
 
 void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
   const Tensor& resource_handle_t = ctx->input(0);
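Aside from the signature line, the kernel body above is unchanged by the commit: the class definition simply moves out of the anonymous namespace into the header (next hunk) so the device registration macro can reference `data::IteratorGetNextAsOptionalOp` by name. The control flow is the standard `AsyncOpKernel` pattern: return immediately, run the potentially blocking `GetNext()` on the op's own worker thread, write either an optional-with-value or an optional-none output, and invoke `done()` exactly once at the end. A minimal, self-contained sketch of that pattern in plain C++ (no TensorFlow; `GetNext`, the iterator state, and the callbacks are all stand-ins):

#include <functional>
#include <iostream>
#include <optional>
#include <thread>

using DoneCallback = std::function<void()>;

// Stand-in for iterator->GetNext(): may block; signals end-of-sequence by
// returning an empty optional, the analogue of WriteOptionalNoneToOutput.
std::optional<int> GetNext(int& state) {
  return state < 3 ? std::optional<int>(state++) : std::nullopt;
}

// Returns immediately after scheduling the blocking call on another thread,
// the way ComputeAsync hands work to background_worker_.Schedule().
std::thread ComputeAsync(int& iterator_state, DoneCallback done) {
  return std::thread([&iterator_state, done = std::move(done)] {
    if (std::optional<int> next = GetNext(iterator_state)) {
      std::cout << "optional with value " << *next << "\n";
    } else {
      std::cout << "optional none (end of sequence)\n";
    }
    done();  // must run exactly once, after all outputs are written
  });
}

int main() {
  int state = 2;  // one element left before end-of-sequence
  ComputeAsync(state, [] { std::cout << "first done\n"; }).join();
  ComputeAsync(state, [] { std::cout << "second done\n"; }).join();
}

The real kernel additionally unrefs the `IteratorResource` before `done()` runs, per the `NOTE(mrry)` comment above, so the callback can never race with the resource's destruction.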


@@ -19,6 +19,8 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/ops_util.h"
 
 namespace tensorflow {
@@ -115,6 +117,24 @@ class IteratorGetNextOp : public AsyncOpKernel {
   BackgroundWorker background_worker_;
 };
 
+class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
+ public:
+  explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
+      : AsyncOpKernel(ctx),
+        background_worker_(ctx->env(),
+                           "tf_data_iterator_get_next_as_optional") {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
+
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
+
+ private:
+  BackgroundWorker background_worker_;
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+};
+
 class IteratorGetNextSyncOp : public OpKernel {
  public:
   explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
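For context on the new members: `output_shapes_` holds one `PartialTensorShape` per output component, and `ComputeAsync` accepts a produced tensor whenever its concrete shape is compatible with the declared partial shape, unknown dimensions matching anything. A rough, self-contained stand-in for that `IsCompatibleWith` check (simplified: TF's real version also treats a shape of unknown rank as compatible with everything):

#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kUnknownDim = -1;  // an unknown dimension in a partial shape

// Simplified PartialTensorShape::IsCompatibleWith: same rank, and every
// known dimension must agree exactly.
bool IsCompatibleWith(const std::vector<int>& declared,
                      const std::vector<int>& actual) {
  if (declared.size() != actual.size()) return false;
  for (std::size_t i = 0; i < declared.size(); ++i) {
    if (declared[i] != kUnknownDim && declared[i] != actual[i]) return false;
  }
  return true;
}

int main() {
  // Declared [?, 28, 28] accepts any batch size...
  assert(IsCompatibleWith({kUnknownDim, 28, 28}, {32, 28, 28}));
  // ...but rejects a wrong known dimension or a wrong rank, which is what
  // triggers the "does not match the expected shape" error in the kernel.
  assert(!IsCompatibleWith({kUnknownDim, 28, 28}, {32, 28, 27}));
  assert(!IsCompatibleWith({kUnknownDim, 28, 28}, {32, 28}));
}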