From 98b8320ebd0a65c7817ee4145a4ec95da32ce6cb Mon Sep 17 00:00:00 2001
From: zilinzhu <zilinzhu@tencent.com>
Date: Wed, 3 Jun 2020 17:13:00 +0800
Subject: [PATCH 1/2] add shape and type check for IteratorGetNextOp and
 ToSingleElementOp

---
 tensorflow/core/kernels/data/iterator_ops.cc | 39 ++++++++++++++++++--
 tensorflow/core/kernels/data/iterator_ops.h  |  9 ++++-
 2 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 9fb3c5fb46e..7197d622662 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -548,7 +548,10 @@ namespace {
 class ToSingleElementOp : public HybridAsyncOpKernel {
  public:
   explicit ToSingleElementOp(OpKernelConstruction* ctx)
-      : HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {}
+      : HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
 
  protected:
   Status DoCompute(OpKernelContext* ctx) override {
@@ -581,7 +584,20 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
       return errors::InvalidArgument("Dataset was empty.");
     }
     for (int i = 0; i < components.size(); ++i) {
-      // TODO(mrry): Check that the shapes match the shape attrs.
+      if (components[i].dtype() != output_types_[i]) {
+        return errors::InvalidArgument(
+            "The result does not match the expected type for "
+            "component ",
+            i, ". Expected: ", DataTypeString(output_types_[i]),
+            ". Actual: ", DataTypeString(components[i].dtype()), ".");
+      }
+      if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
+        return errors::InvalidArgument(
+            "The result does not match the expected shape "
+            "for component ",
+            i, ". Expected: ", output_shapes_[i].DebugString(),
+            ". Actual: ", components[i].shape().DebugString(), ".");
+      }
       ctx->set_output(i, components[i]);
     }
 
@@ -593,6 +609,10 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
     }
     return Status::OK();
   }
+
+ private:
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
 };
 
 class ReduceDatasetOp : public HybridAsyncOpKernel {
@@ -918,7 +938,20 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
     return errors::OutOfRange("End of sequence");
   }
   for (int i = 0; i < components.size(); ++i) {
-    // TODO(mrry): Check that the shapes match the shape attrs.
+    if (components[i].dtype() != output_types_[i]) {
+      return errors::InvalidArgument(
+          "The result does not match the expected type for "
+          "component ",
+          i, ". Expected: ", DataTypeString(output_types_[i]),
+          ". Actual: ", DataTypeString(components[i].dtype()), ".");
+    }
+    if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
+      return errors::InvalidArgument(
+          "The result does not match the expected shape "
+          "for component ",
+          i, ". Expected: ", output_shapes_[i].DebugString(),
+          ". Actual: ", components[i].shape().DebugString(), ".");
+    }
     ctx->set_output(i, components[i]);
   }
   return Status::OK();
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index 86db80ed75c..938b218bcb7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -216,12 +216,19 @@ class MakeIteratorOp : public HybridAsyncOpKernel {
 class IteratorGetNextOp : public HybridAsyncOpKernel {
  public:
   explicit IteratorGetNextOp(OpKernelConstruction* ctx)
-      : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {}
+      : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
 
   AsyncOpKernel* AsAsync() override;
 
  protected:
   Status DoCompute(OpKernelContext* ctx) override;
+
+ private:
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
 };
 
 class DeleteIteratorOp : public HybridAsyncOpKernel {

From bac4a2b9ba760da19ccc37f235e48bde1f13377c Mon Sep 17 00:00:00 2001
From: zilinzhu <zilinzhu@tencent.com>
Date: Fri, 5 Jun 2020 12:15:48 +0800
Subject: [PATCH 2/2] refactor VerifyTypesMatch and VerifyShapeCompatible

---
 tensorflow/core/kernels/data/dataset_utils.cc | 63 +++++++++++++++----
 tensorflow/core/kernels/data/dataset_utils.h  |  6 ++
 tensorflow/core/kernels/data/iterator_ops.cc  | 60 ++----------------
 3 files changed, 63 insertions(+), 66 deletions(-)

diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index 15d6438bd02..d8d2188fa37 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -455,6 +455,17 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
   return Status::OK();
 }
 
+Status VerifyTypeMatch(const DataType& expected,
+                       const DataType& received, int index) {
+  if (expected != received) {
+    return errors::InvalidArgument("Data type mismatch at component ", index,
+                                   ": expected ", DataTypeString(expected),
+                                   " but got ", DataTypeString(received),
+                                   ".");
+  }
+  return Status::OK();
+}
+
 Status VerifyTypesMatch(const DataTypeVector& expected,
                         const DataTypeVector& received) {
   if (expected.size() != received.size()) {
@@ -463,12 +474,31 @@ Status VerifyTypesMatch(const DataTypeVector& expected,
         " types but got ", received.size(), ".");
   }
   for (size_t i = 0; i < expected.size(); ++i) {
-    if (expected[i] != received[i]) {
-      return errors::InvalidArgument("Data type mismatch at component ", i,
-                                     ": expected ", DataTypeString(expected[i]),
-                                     " but got ", DataTypeString(received[i]),
-                                     ".");
-    }
+    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
+  }
+  return Status::OK();
+}
+
+Status VerifyTypesMatch(const DataTypeVector& expected,
+                        const std::vector<Tensor>& received) {
+  if (expected.size() != received.size()) {
+    return errors::InvalidArgument(
+        "Number of components does not match: expected ", expected.size(),
+        " types but got ", received.size(), ".");
+  }
+  for (size_t i = 0; i < expected.size(); ++i) {
+    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
+  }
+  return Status::OK();
+}
+
+Status VerifyShapeCompatible(const PartialTensorShape& expected,
+                             const PartialTensorShape& received, int index) {
+  if (!expected.IsCompatibleWith(received)) {
+    return errors::InvalidArgument("Incompatible shapes at component ", index,
+                                   ": expected ", expected.DebugString(),
+                                   " but got ", received.DebugString(),
+                                   ".");
   }
   return Status::OK();
 }
@@ -481,12 +511,21 @@ Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
         " shapes but got ", received.size(), ".");
   }
   for (size_t i = 0; i < expected.size(); ++i) {
-    if (!expected[i].IsCompatibleWith(received[i])) {
-      return errors::InvalidArgument("Incompatible shapes at component ", i,
-                                     ": expected ", expected[i].DebugString(),
-                                     " but got ", received[i].DebugString(),
-                                     ".");
-    }
+    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
+  }
+
+  return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+                              const std::vector<Tensor>& received) {
+  if (expected.size() != received.size()) {
+    return errors::InvalidArgument(
+        "Number of components does not match: expected ", expected.size(),
+        " shapes but got ", received.size(), ".");
+  }
+  for (size_t i = 0; i < expected.size(); ++i) {
+    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i].shape(), i));
   }
 
   return Status::OK();
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 70ca70176e8..ac087360fd0 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -94,11 +94,17 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
 Status VerifyTypesMatch(const DataTypeVector& expected,
                         const DataTypeVector& received);
 
+Status VerifyTypesMatch(const DataTypeVector& expected,
+                        const std::vector<Tensor>& received);
+
 // Returns Status::OK() if `expected` and `received` shapes are compatible,
 // errors::InvalidArgument otherwise.
 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                               const std::vector<PartialTensorShape>& received);
 
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+                              const std::vector<Tensor>& received);
+
 // Returns a stable hash of the subgraph rooted at the given node.
 //
 // NOTE: There is currently no guarantee that the hash of a subgraph will stay
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 7197d622662..8dd7f4c364b 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -583,21 +583,9 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
     if (end_of_sequence) {
       return errors::InvalidArgument("Dataset was empty.");
     }
+    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
+    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
     for (int i = 0; i < components.size(); ++i) {
-      if (components[i].dtype() != output_types_[i]) {
-        return errors::InvalidArgument(
-            "The result does not match the expected type for "
-            "component ",
-            i, ". Expected: ", DataTypeString(output_types_[i]),
-            ". Actual: ", DataTypeString(components[i].dtype()), ".");
-      }
-      if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
-        return errors::InvalidArgument(
-            "The result does not match the expected shape "
-            "for component ",
-            i, ". Expected: ", output_shapes_[i].DebugString(),
-            ". Actual: ", components[i].shape().DebugString(), ".");
-      }
       ctx->set_output(i, components[i]);
     }
 
@@ -694,33 +682,9 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
       std::swap(reduce_func_output, state);
     }
 
-    if (state.size() != output_types_.size()) {
-      return errors::InvalidArgument(
-          "The number of result elements does not match "
-          "the size of output types: ",
-          state.size(), " vs. ", output_types_.size());
-    }
-    if (state.size() != output_shapes_.size()) {
-      return errors::InvalidArgument(
-          "The number of result elements does not match "
-          "the size of output shapes: ",
-          state.size(), " vs. ", output_shapes_.size());
-    }
+    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state));
+    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state));
     for (size_t i = 0; i < state.size(); ++i) {
-      if (state[i].dtype() != output_types_[i]) {
-        return errors::InvalidArgument(
-            "The result does not match the expected type for "
-            "component ",
-            i, ". Expected: ", DataTypeString(output_types_[i]),
-            ". Actual: ", DataTypeString(state[i].dtype()), ".");
-      }
-      if (!output_shapes_[i].IsCompatibleWith(state[i].shape())) {
-        return errors::InvalidArgument(
-            "The result does not match the expected shape for "
-            "component ",
-            i, ". Expected: ", output_shapes_[i].DebugString(),
-            ". Actual: ", state[i].shape().DebugString(), ".");
-      }
       ctx->set_output(i, state[i]);
     }
     return Status::OK();
@@ -937,21 +901,9 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
   if (end_of_sequence) {
     return errors::OutOfRange("End of sequence");
   }
+  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
+  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
   for (int i = 0; i < components.size(); ++i) {
-    if (components[i].dtype() != output_types_[i]) {
-      return errors::InvalidArgument(
-          "The result does not match the expected type for "
-          "component ",
-          i, ". Expected: ", DataTypeString(output_types_[i]),
-          ". Actual: ", DataTypeString(components[i].dtype()), ".");
-    }
-    if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
-      return errors::InvalidArgument(
-          "The result does not match the expected shape "
-          "for component ",
-          i, ". Expected: ", output_shapes_[i].DebugString(),
-          ". Actual: ", components[i].shape().DebugString(), ".");
-    }
     ctx->set_output(i, components[i]);
   }
   return Status::OK();