diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index 15d6438bd02..3161004b7ab 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -455,6 +455,16 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager, return Status::OK(); } +Status VerifyTypeMatch(const DataType& expected, const DataType& received, + int index) { + if (expected != received) { + return errors::InvalidArgument("Data type mismatch at component ", index, + ": expected ", DataTypeString(expected), + " but got ", DataTypeString(received), "."); + } + return Status::OK(); +} + Status VerifyTypesMatch(const DataTypeVector& expected, const DataTypeVector& received) { if (expected.size() != received.size()) { @@ -463,12 +473,30 @@ Status VerifyTypesMatch(const DataTypeVector& expected, " types but got ", received.size(), "."); } for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != received[i]) { - return errors::InvalidArgument("Data type mismatch at component ", i, - ": expected ", DataTypeString(expected[i]), - " but got ", DataTypeString(received[i]), - "."); - } + TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i)); + } + return Status::OK(); +} + +Status VerifyTypesMatch(const DataTypeVector& expected, + const std::vector<Tensor>& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " types but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i)); + } + return Status::OK(); +} + +Status VerifyShapeCompatible(const PartialTensorShape& expected, + const PartialTensorShape& received, int index) { + if (!expected.IsCompatibleWith(received)) { + return errors::InvalidArgument("Incompatible shapes at component ", index, + ": expected ", expected.DebugString(), + " but got ", received.DebugString(), "."); } return Status::OK(); } @@ -481,12 +509,22 @@ Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, " shapes but got ", received.size(), "."); } for (size_t i = 0; i < expected.size(); ++i) { - if (!expected[i].IsCompatibleWith(received[i])) { - return errors::InvalidArgument("Incompatible shapes at component ", i, - ": expected ", expected[i].DebugString(), - " but got ", received[i].DebugString(), - "."); - } + TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i)); + } + + return Status::OK(); +} + +Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, + const std::vector<Tensor>& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " shapes but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + TF_RETURN_IF_ERROR( + VerifyShapeCompatible(expected[i], received[i].shape(), i)); } return Status::OK(); diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 70ca70176e8..ac087360fd0 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -94,11 +94,17 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager, Status VerifyTypesMatch(const DataTypeVector& expected, const DataTypeVector& received); +Status VerifyTypesMatch(const DataTypeVector& expected, + const std::vector<Tensor>& received); + // Returns Status::OK() if `expected` and `received` shapes are compatible, // errors::InvalidArgument otherwise. Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, const std::vector<PartialTensorShape>& received); +Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, + const std::vector<Tensor>& received); + // Returns a stable hash of the subgraph rooted at the given node. // // NOTE: There is currently no guarantee that the hash of a subgraph will stay diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 9fb3c5fb46e..8dd7f4c364b 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -548,7 +548,10 @@ namespace { class ToSingleElementOp : public HybridAsyncOpKernel { public: explicit ToSingleElementOp(OpKernelConstruction* ctx) - : HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {} + : HybridAsyncOpKernel(ctx, "tf_data_to_single_element") { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } protected: Status DoCompute(OpKernelContext* ctx) override { @@ -580,8 +583,9 @@ class ToSingleElementOp : public HybridAsyncOpKernel { if (end_of_sequence) { return errors::InvalidArgument("Dataset was empty."); } + TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components)); + TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components)); for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. ctx->set_output(i, components[i]); } @@ -593,6 +597,10 @@ class ToSingleElementOp : public HybridAsyncOpKernel { } return Status::OK(); } + + private: + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; }; class ReduceDatasetOp : public HybridAsyncOpKernel { @@ -674,33 +682,9 @@ class ReduceDatasetOp : public HybridAsyncOpKernel { std::swap(reduce_func_output, state); } - if (state.size() != output_types_.size()) { - return errors::InvalidArgument( - "The number of result elements does not match " - "the size of output types: ", - state.size(), " vs. ", output_types_.size()); - } - if (state.size() != output_shapes_.size()) { - return errors::InvalidArgument( - "The number of result elements does not match " - "the size of output shapes: ", - state.size(), " vs. ", output_shapes_.size()); - } + TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state)); + TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state)); for (size_t i = 0; i < state.size(); ++i) { - if (state[i].dtype() != output_types_[i]) { - return errors::InvalidArgument( - "The result does not match the expected type for " - "component ", - i, ". Expected: ", DataTypeString(output_types_[i]), - ". Actual: ", DataTypeString(state[i].dtype()), "."); - } - if (!output_shapes_[i].IsCompatibleWith(state[i].shape())) { - return errors::InvalidArgument( - "The result does not match the expected shape for " - "component ", - i, ". Expected: ", output_shapes_[i].DebugString(), - ". Actual: ", state[i].shape().DebugString(), "."); - } ctx->set_output(i, state[i]); } return Status::OK(); @@ -917,8 +901,9 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { if (end_of_sequence) { return errors::OutOfRange("End of sequence"); } + TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components)); + TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components)); for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. ctx->set_output(i, components[i]); } return Status::OK(); diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index 86db80ed75c..938b218bcb7 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -216,12 +216,19 @@ class MakeIteratorOp : public HybridAsyncOpKernel { class IteratorGetNextOp : public HybridAsyncOpKernel { public: explicit IteratorGetNextOp(OpKernelConstruction* ctx) - : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {} + : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } AsyncOpKernel* AsAsync() override; protected: Status DoCompute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; }; class DeleteIteratorOp : public HybridAsyncOpKernel {