Support IteratorGetNextAsOptionalOp in TPU.
PiperOrigin-RevId: 232725342
This commit is contained in:
parent
3f8dcd3e28
commit
7bbc65be71
@ -241,6 +241,8 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
|
||||
data::IteratorGetNextOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
|
||||
data::IteratorGetNextAsOptionalOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \
|
||||
data::IteratorGetNextSyncOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
|
||||
|
@ -967,78 +967,58 @@ void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void IteratorGetNextAsOptionalOp::ComputeAsync(OpKernelContext* ctx,
|
||||
DoneCallback done) {
|
||||
IteratorResource* iterator;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
|
||||
// The call to `iterator->GetNext()` may block and depend on an
|
||||
// inter-op thread pool thread, so we issue the call from the
|
||||
// owned thread pool.
|
||||
background_worker_.Schedule(std::bind(
|
||||
[this, ctx, iterator](DoneCallback done) {
|
||||
std::vector<Tensor> components;
|
||||
bool end_of_sequence = false;
|
||||
|
||||
class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
|
||||
: AsyncOpKernel(ctx),
|
||||
background_worker_(ctx->env(),
|
||||
"tf_data_iterator_get_next_as_optional") {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
}
|
||||
Status s = iterator->GetNext(IteratorContext(ctx), &components,
|
||||
&end_of_sequence);
|
||||
// NOTE(mrry): We must unref the iterator before calling `done()`, to
|
||||
// avoid destruction races.
|
||||
iterator->Unref();
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
IteratorResource* iterator;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
|
||||
// The call to `iterator->GetNext()` may block and depend on an
|
||||
// inter-op thread pool thread, so we issue the call from the
|
||||
// owned thread pool.
|
||||
background_worker_.Schedule(std::bind(
|
||||
[this, ctx, iterator](DoneCallback done) {
|
||||
std::vector<Tensor> components;
|
||||
bool end_of_sequence = false;
|
||||
|
||||
Status s = iterator->GetNext(IteratorContext(ctx), &components,
|
||||
&end_of_sequence);
|
||||
// NOTE(mrry): We must unref the iterator before calling `done()`, to
|
||||
// avoid destruction races.
|
||||
iterator->Unref();
|
||||
|
||||
if (!s.ok()) {
|
||||
ctx->SetStatus(s);
|
||||
} else if (end_of_sequence) {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
|
||||
} else {
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, components[i].dtype() == output_types_[i],
|
||||
errors::InvalidArgument(
|
||||
"The given optional does not match the expected type for "
|
||||
"component ",
|
||||
i, ". Expected: ", DataTypeString(output_types_[i]),
|
||||
". Actual: ", DataTypeString(components[i].dtype()), "."),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx,
|
||||
output_shapes_[i].IsCompatibleWith(components[i].shape()),
|
||||
errors::InvalidArgument(
|
||||
"The given optional does not match the expected shape "
|
||||
"for component ",
|
||||
i, ". Expected: ", output_shapes_[i].DebugString(),
|
||||
". Actual: ", components[i].shape().DebugString(), "."),
|
||||
done);
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
|
||||
if (!s.ok()) {
|
||||
ctx->SetStatus(s);
|
||||
} else if (end_of_sequence) {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
|
||||
} else {
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, components[i].dtype() == output_types_[i],
|
||||
errors::InvalidArgument(
|
||||
"The given optional does not match the expected type for "
|
||||
"component ",
|
||||
i, ". Expected: ", DataTypeString(output_types_[i]),
|
||||
". Actual: ", DataTypeString(components[i].dtype()), "."),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
|
||||
errors::InvalidArgument(
|
||||
"The given optional does not match the expected shape "
|
||||
"for component ",
|
||||
i, ". Expected: ", output_shapes_[i].DebugString(),
|
||||
". Actual: ", components[i].shape().DebugString(), "."),
|
||||
done);
|
||||
}
|
||||
done();
|
||||
},
|
||||
std::move(done)));
|
||||
}
|
||||
|
||||
private:
|
||||
BackgroundWorker background_worker_;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
|
||||
done);
|
||||
}
|
||||
done();
|
||||
},
|
||||
std::move(done)));
|
||||
}
|
||||
|
||||
void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
|
||||
const Tensor& resource_handle_t = ctx->input(0);
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -115,6 +117,24 @@ class IteratorGetNextOp : public AsyncOpKernel {
|
||||
BackgroundWorker background_worker_;
|
||||
};
|
||||
|
||||
class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
|
||||
: AsyncOpKernel(ctx),
|
||||
background_worker_(ctx->env(),
|
||||
"tf_data_iterator_get_next_as_optional") {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
|
||||
|
||||
private:
|
||||
BackgroundWorker background_worker_;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
class IteratorGetNextSyncOp : public OpKernel {
|
||||
public:
|
||||
explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
Loading…
Reference in New Issue
Block a user