add shape and type check for IteratorGetNextOp and ToSingleElementOp
This commit is contained in:
parent
6afec5ecbb
commit
98b8320ebd
tensorflow/core/kernels/data
@ -548,7 +548,10 @@ namespace {
|
||||
class ToSingleElementOp : public HybridAsyncOpKernel {
|
||||
public:
|
||||
explicit ToSingleElementOp(OpKernelConstruction* ctx)
|
||||
: HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {}
|
||||
: HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
}
|
||||
|
||||
protected:
|
||||
Status DoCompute(OpKernelContext* ctx) override {
|
||||
@ -581,7 +584,20 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
|
||||
return errors::InvalidArgument("Dataset was empty.");
|
||||
}
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||
if (components[i].dtype() != output_types_[i]) {
|
||||
return errors::InvalidArgument(
|
||||
"The result does not match the expected type for "
|
||||
"component ",
|
||||
i, ". Expected: ", DataTypeString(output_types_[i]),
|
||||
". Actual: ", DataTypeString(components[i].dtype()), ".");
|
||||
}
|
||||
if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"The result does not match the expected shape "
|
||||
"for component ",
|
||||
i, ". Expected: ", output_shapes_[i].DebugString(),
|
||||
". Actual: ", components[i].shape().DebugString(), ".");
|
||||
}
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
|
||||
@ -593,6 +609,10 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
class ReduceDatasetOp : public HybridAsyncOpKernel {
|
||||
@ -918,7 +938,20 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
|
||||
return errors::OutOfRange("End of sequence");
|
||||
}
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||
if (components[i].dtype() != output_types_[i]) {
|
||||
return errors::InvalidArgument(
|
||||
"The result does not match the expected type for "
|
||||
"component ",
|
||||
i, ". Expected: ", DataTypeString(output_types_[i]),
|
||||
". Actual: ", DataTypeString(components[i].dtype()), ".");
|
||||
}
|
||||
if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"The result does not match the expected shape "
|
||||
"for component ",
|
||||
i, ". Expected: ", output_shapes_[i].DebugString(),
|
||||
". Actual: ", components[i].shape().DebugString(), ".");
|
||||
}
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -216,12 +216,19 @@ class MakeIteratorOp : public HybridAsyncOpKernel {
|
||||
class IteratorGetNextOp : public HybridAsyncOpKernel {
|
||||
public:
|
||||
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
|
||||
: HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {}
|
||||
: HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
}
|
||||
|
||||
AsyncOpKernel* AsAsync() override;
|
||||
|
||||
protected:
|
||||
Status DoCompute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
class DeleteIteratorOp : public HybridAsyncOpKernel {
|
||||
|
Loading…
Reference in New Issue
Block a user