add shape and type check for IteratorGetNextOp and ToSingleElementOp

This commit is contained in:
zilinzhu 2020-06-03 17:13:00 +08:00
parent 6afec5ecbb
commit 98b8320ebd
2 changed files with 44 additions and 4 deletions
tensorflow/core/kernels/data

View File

@ -548,7 +548,10 @@ namespace {
class ToSingleElementOp : public HybridAsyncOpKernel {
public:
explicit ToSingleElementOp(OpKernelConstruction* ctx)
: HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {}
: HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
protected:
Status DoCompute(OpKernelContext* ctx) override {
@ -581,7 +584,20 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
return errors::InvalidArgument("Dataset was empty.");
}
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
if (components[i].dtype() != output_types_[i]) {
return errors::InvalidArgument(
"The result does not match the expected type for "
"component ",
i, ". Expected: ", DataTypeString(output_types_[i]),
". Actual: ", DataTypeString(components[i].dtype()), ".");
}
if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
return errors::InvalidArgument(
"The result does not match the expected shape "
"for component ",
i, ". Expected: ", output_shapes_[i].DebugString(),
". Actual: ", components[i].shape().DebugString(), ".");
}
ctx->set_output(i, components[i]);
}
@ -593,6 +609,10 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
}
return Status::OK();
}
private:
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
class ReduceDatasetOp : public HybridAsyncOpKernel {
@ -918,7 +938,20 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
return errors::OutOfRange("End of sequence");
}
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
if (components[i].dtype() != output_types_[i]) {
return errors::InvalidArgument(
"The result does not match the expected type for "
"component ",
i, ". Expected: ", DataTypeString(output_types_[i]),
". Actual: ", DataTypeString(components[i].dtype()), ".");
}
if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
return errors::InvalidArgument(
"The result does not match the expected shape "
"for component ",
i, ". Expected: ", output_shapes_[i].DebugString(),
". Actual: ", components[i].shape().DebugString(), ".");
}
ctx->set_output(i, components[i]);
}
return Status::OK();

View File

@ -216,12 +216,19 @@ class MakeIteratorOp : public HybridAsyncOpKernel {
class IteratorGetNextOp : public HybridAsyncOpKernel {
public:
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
: HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {}
: HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
AsyncOpKernel* AsAsync() override;
protected:
Status DoCompute(OpKernelContext* ctx) override;
private:
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
class DeleteIteratorOp : public HybridAsyncOpKernel {