From 1d3326cb6696e760e4b87a04668eb522d19673c4 Mon Sep 17 00:00:00 2001
From: Fei Hu <hufei68@gmail.com>
Date: Wed, 4 Sep 2019 15:50:03 -0700
Subject: [PATCH] Address the comments

---
 tensorflow/core/framework/function_testlib.h  | 58 +++++++++----------
 .../kernels/data/batch_dataset_op_test.cc     | 18 +++---
 .../core/kernels/data/dataset_test_base.cc    | 31 +++++-----
 .../core/kernels/data/dataset_test_base.h     | 30 ++++++----
 4 files changed, 71 insertions(+), 66 deletions(-)

diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index eb5df661e56..06b9d48c289 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -68,16 +68,16 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
 // For testing convenience, we provide a few simple functions that can
 // be easily executed and tested.
 
-// x:T -> x * 2.
+// x: T -> x * 2.
 FunctionDef XTimesTwo();
 
-// x:T -> cpu(x * 2) + cpu(x * 3).
+// x: T -> cpu(x * 2) + cpu(x * 3).
 FunctionDef TwoDeviceTimesFive();
 
-// x:T -> cpu(x * 2), gpu(x * 3).
+// x: T -> cpu(x * 2), gpu(x * 3).
 FunctionDef TwoDeviceMult();
 
-// cpu(x):T, gpu(y):T -> cpu(x * 2), gpu(y * 3).
+// cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3).
 FunctionDef TwoDeviceInputOutput();
 
 // Function taking a list of Tensors as input.
@@ -86,25 +86,25 @@ FunctionDef FuncWithListInput();
 // Function returning a list of Tensors as output.
 FunctionDef FuncWithListOutput();
 
-// x:T -> x + x.
+// x: T -> x + x.
 FunctionDef XAddX();
 
-// x: T, y:T -> x + y.
+// x: T, y: T -> x + y.
 FunctionDef XAddY();
 
-// x:T -> x * 2, where x is int32.
+// x: T -> x * 2, where x is int32.
 FunctionDef XTimesTwoInt32();
 
-// x:T -> (x * 2) * 2.
+// x: T -> (x * 2) * 2.
 FunctionDef XTimesFour();
 
-// x:T -> ((x * 2) * 2) * 2.
+// x: T -> ((x * 2) * 2) * 2.
 FunctionDef XTimes16();
 
-// w:T, x:T, b:T -> MatMul(w, x) + b
+// w: T, x: T, b: T -> MatMul(w, x) + b
 FunctionDef WXPlusB();
 
-// x:T -> x:T, T is a type which we automatically converts to a bool.
+// x: T -> x: T, T is a type which we automatically converts to a bool.
 FunctionDef NonZero();
 
 // x: T -> bool.
@@ -113,56 +113,56 @@ FunctionDef IsZero();
 // x: T -> int64
 FunctionDef RandomUniform();
 
-// x:T, y:T -> y:T, x:T
+// x: T, y:T  -> y: T, x: T
 FunctionDef Swap();
 
-// x:T, y:T -> y:T, x:T, the body has no nodes.
+// x: T, y: T -> y: T, x: T, the body has no nodes.
 FunctionDef EmptyBodySwap();
 
-// x:float, y:resource -> y:resource, 2*x:float.
+// x: float, y: resource -> y: resource, 2*x: float.
 FunctionDef ResourceOutput();
 
-// x:resource -> x:resource
+// x: resource -> x: resource
 FunctionDef ResourceIdentity();
 
-// x:resource -> y:float.
+// x: resource -> y: float.
 FunctionDef ReadResourceVariable();
 
 // Contains malformed control flow which can't be run by the executor.
 FunctionDef InvalidControlFlow();
 
-// x:T -> x <= N.
+// x: T -> x <= N.
 FunctionDef LessThanOrEqualToN(int64 N);
 
-// x:T, y:T -> x+1, x*y
+// x: T, y: T -> x + 1, x * y
 FunctionDef XPlusOneXTimesY();
 
-// x:T, y:T -> x <= N
+// x: T, y: T -> x <= N
 FunctionDef XYXLessThanOrEqualToN(int64 N);
 
 // x: T -> bool
 FunctionDef RandomUniformLess();
 
-// start:int64, stop:int64, step:int64 -> y:RangeDatasetOp::Dataset
+// start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset
 FunctionDef MakeRangeDataset();
 
-// input_dataset:variant, batch_size:int64, drop_remainder:bool
-// -> y:BatchDatasetV2::Dataset
+// input_dataset: variant, batch_size: int64, drop_remainder: bool
+// -> y: BatchDatasetV2::Dataset
 FunctionDef MakeBatchDataset();
 
-// input_dataset:variant, other_arguments:Targuments, f:func,
-// Targuments:list(type), output_types:list(type), output_shapes:list(shape),
-// use_inter_op_parallelism:bool, preserve_cardinality:bool
-// -> y:MapDatasetOp::Dataset
+// input_dataset: variant, other_arguments: Targuments, f: func,
+// Targuments: list(type), output_types: list(type), output_shapes: list(shape),
+// use_inter_op_parallelism: bool, preserve_cardinality: bool
+// -> y: MapDatasetOp::Dataset
 FunctionDef MakeMapDataset(bool has_other_args);
 
-// input_dataset:variant, count:int64 -> y:TakeDataset::Dataset
+// input_dataset: variant, count: int64 -> y: TakeDataset::Dataset
 FunctionDef MakeTakeDataset();
 
-// x:T -> y:TensorSliceDatasetOp::Dataset
+// x: T -> y: TensorSliceDatasetOp::Dataset
 FunctionDef MakeTensorSliceDataset();
 
-// x:T -> y:T, idx:out_idx
+// x: T -> y: T, idx: out_idx
 FunctionDef Unique();
 
 void FunctionTestSchedClosure(std::function<void()> fn);
diff --git a/tensorflow/core/kernels/data/batch_dataset_op_test.cc b/tensorflow/core/kernels/data/batch_dataset_op_test.cc
index ecffd7a46bc..2ba4d777c14 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op_test.cc
@@ -17,7 +17,7 @@ namespace tensorflow {
 namespace data {
 namespace {
 
-constexpr char kNodeName[] = "batch_dataset_v2";
+constexpr char kNodeName[] = "batch_dataset";
 
 class BatchDatasetOpTest : public DatasetOpsTestBaseV2 {};
 
@@ -30,7 +30,7 @@ BatchDatasetParams BatchDatasetParams1() {
                             /*parallel_copy=*/true,
                             /*output_dtypes=*/{DT_INT64},
                             /*output_shapes=*/{PartialTensorShape({4})},
-                            /*node_name=*/"batch_dataset_v2");
+                            /*node_name=*/"batch_dataset");
 }
 
 // Test Case 2: test BatchDatasetV2 with `drop_remainder` = true and a batch
@@ -42,7 +42,7 @@ BatchDatasetParams BatchDatasetParams2() {
                             /*parallel_copy=*/false,
                             /*output_dtypes=*/{DT_INT64},
                             /*output_shapes=*/{PartialTensorShape({4})},
-                            /*node_name=*/"batch_dataset_v2");
+                            /*node_name=*/"batch_dataset");
 }
 
 // Test Case 3: test BatchDatasetV2 with `drop_remainder` = false and a batch
@@ -54,7 +54,7 @@ BatchDatasetParams BatchDatasetParams3() {
                             /*parallel_copy=*/false,
                             /*output_dtypes=*/{DT_INT64},
                             /*output_shapes=*/{PartialTensorShape({-1})},
-                            /*node_name=*/"batch_dataset_0");
+                            /*node_name=*/"batch_dataset");
 }
 
 // Test Case 4: test BatchDatasetV2 with `drop_remainder` = true and a batch
@@ -66,7 +66,7 @@ BatchDatasetParams BatchDatasetParams4() {
                             /*parallel_copy=*/true,
                             /*output_dtypes=*/{DT_INT64},
                             /*output_shapes=*/{PartialTensorShape({3})},
-                            /*node_name=*/"batch_dataset_v2");
+                            /*node_name=*/"batch_dataset");
 }
 
 // Test Case 5: test BatchDatasetV2 with `drop_remainder` = true and
@@ -78,7 +78,7 @@ BatchDatasetParams BatchDatasetParams5() {
                             /*parallel_copy=*/true,
                             /*output_dtypes=*/{DT_INT64},
                             /*output_shapes=*/{PartialTensorShape({12})},
-                            /*node_name=*/kNodeName);
+                            /*node_name=*/"batch_dataset");
 }
 
 // Test Case 6: test BatchDatasetV2 with `drop_remainder` = false and
@@ -90,7 +90,7 @@ BatchDatasetParams BatchDatasetParams6() {
                             /*parallel_copy=*/true,
                             /*output_dtypes=*/{DT_INT64},
                             /*output_shapes=*/{PartialTensorShape({-1})},
-                            /*node_name=*/"batch_dataset_v2");
+                            /*node_name=*/"batch_dataset");
 }
 
 // Test Case 7: test BatchDatasetV2 with `drop_remainder` = false and
@@ -102,7 +102,7 @@ BatchDatasetParams BatchDatasetParams7() {
                             /*parallel_copy=*/false,
                             /*output_dtypes=*/{DT_INT64},
                             /*output_shapes=*/{PartialTensorShape({4})},
-                            /*node_name=*/"batch_dataset_v2");
+                            /*node_name=*/"batch_dataset");
 }
 
 // Test Case 8: test BatchDatasetV2 with an invalid batch size
@@ -113,7 +113,7 @@ BatchDatasetParams InvalidBatchSizeBatchDatasetParams() {
                             /*parallel_copy=*/false,
                             /*output_dtypes=*/{DT_INT64},
                             /*output_shapes=*/{PartialTensorShape({3})},
-                            /*node_name=*/"batch_dataset_v2");
+                            /*node_name=*/"batch_dataset");
 }
 
 std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc
index a4c00b6f8db..fb5c092812e 100644
--- a/tensorflow/core/kernels/data/dataset_test_base.cc
+++ b/tensorflow/core/kernels/data/dataset_test_base.cc
@@ -688,7 +688,7 @@ Status DatasetOpsTestBaseV2::Initialize(DatasetParams& dataset_params) {
     TF_RETURN_IF_ERROR(MakeDatasetTensor(pair.first.get(), &pair.second));
   }
   gtl::InlinedVector<TensorValue, 4> inputs;
-  TF_RETURN_IF_ERROR(dataset_params.MakeInputs(&inputs));
+  TF_RETURN_IF_ERROR(dataset_params.GetInputs(&inputs));
   TF_RETURN_IF_ERROR(
       CreateDatasetContext(dataset_kernel_.get(), &inputs, &dataset_ctx_));
   TF_RETURN_IF_ERROR(
@@ -705,9 +705,9 @@ Status DatasetOpsTestBaseV2::MakeDatasetOpKernel(
   name_utils::OpNameParams params;
   params.op_version = dataset_params.op_version();
   std::vector<string> input_placeholder;
-  TF_RETURN_IF_ERROR(dataset_params.MakeInputPlaceholder(&input_placeholder));
+  TF_RETURN_IF_ERROR(dataset_params.GetInputPlaceholder(&input_placeholder));
   AttributeVector attributes;
-  TF_RETURN_IF_ERROR(dataset_params.MakeAttributes(&attributes));
+  TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
   NodeDef node_def = test::function::NDef(
       dataset_params.node_name(),
       name_utils::OpName(ToString(dataset_params.type()), params),
@@ -724,9 +724,9 @@ Status DatasetOpsTestBaseV2::MakeDatasetTensor(DatasetParams* dataset_params,
   }
 
   AttributeVector attributes;
-  TF_RETURN_IF_ERROR(dataset_params->MakeAttributes(&attributes));
+  TF_RETURN_IF_ERROR(dataset_params->GetAttributes(&attributes));
   gtl::InlinedVector<TensorValue, 4> inputs;
-  TF_RETURN_IF_ERROR(dataset_params->MakeInputs(&inputs));
+  TF_RETURN_IF_ERROR(dataset_params->GetInputs(&inputs));
   std::vector<Tensor> input_tensors;
   for (auto& tensor_value : inputs) {
     input_tensors.emplace_back(*tensor_value.tensor);
@@ -756,7 +756,7 @@ Status DatasetOpsTestBaseV2::MakeDatasetTensorFunc(
     case DatasetParamsType::Map: {
       std::vector<string> input_placeholder;
       TF_RETURN_IF_ERROR(
-          dataset_params.MakeInputPlaceholder(&input_placeholder));
+          dataset_params.GetInputPlaceholder(&input_placeholder));
       bool has_other_args = input_placeholder.size() > 1;
       *fdef = test::function::MakeMapDataset(has_other_args);
       break;
@@ -808,26 +808,26 @@ RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step)
       stop_(CreateTensor<int64>(TensorShape({}), {stop})),
       step_(CreateTensor<int64>(TensorShape({}), {step})) {}
 
-Status RangeDatasetParams::MakeInputs(
+Status RangeDatasetParams::GetInputs(
     gtl::InlinedVector<TensorValue, 4>* inputs) {
   *inputs = {TensorValue(&start_), TensorValue(&stop_), TensorValue(&step_)};
   return Status::OK();
 }
 
-Status RangeDatasetParams::MakeInputPlaceholder(
+Status RangeDatasetParams::GetInputPlaceholder(
     std::vector<string>* input_placeholder) const {
   *input_placeholder = {RangeDatasetOp::kStart, RangeDatasetOp::kStop,
                         RangeDatasetOp::kStep};
   return Status::OK();
 }
 
-Status RangeDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
+Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
   *attr_vector = {{RangeDatasetOp::kOutputTypes, output_dtypes_},
                   {RangeDatasetOp::kOutputShapes, output_shapes_}};
   return Status::OK();
 }
 
-Status BatchDatasetParams::MakeInputs(
+Status BatchDatasetParams::GetInputs(
     gtl::InlinedVector<TensorValue, 4>* inputs) {
   inputs->reserve(input_dataset_params_group_.size());
   for (auto& pair : input_dataset_params_group_) {
@@ -844,7 +844,7 @@ Status BatchDatasetParams::MakeInputs(
   return Status::OK();
 }
 
-Status BatchDatasetParams::MakeInputPlaceholder(
+Status BatchDatasetParams::GetInputPlaceholder(
     std::vector<string>* input_placeholder) const {
   *input_placeholder = {BatchDatasetOp::kInputDataset,
                         BatchDatasetOp::kBatchSize,
@@ -852,7 +852,7 @@ Status BatchDatasetParams::MakeInputPlaceholder(
   return Status::OK();
 }
 
-Status BatchDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
+Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
   *attr_vector = {{BatchDatasetOp::kParallelCopy, parallel_copy_},
                   {BatchDatasetOp::kOutputTypes, output_dtypes_},
                   {BatchDatasetOp::kOutputShapes, output_shapes_}};
@@ -861,8 +861,7 @@ Status BatchDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
 
 int BatchDatasetParams::op_version() const { return op_version_; }
 
-Status MapDatasetParams::MakeInputs(
-    gtl::InlinedVector<TensorValue, 4>* inputs) {
+Status MapDatasetParams::GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) {
   inputs->reserve(input_dataset_params_group_.size());
   for (auto& pair : input_dataset_params_group_) {
     if (!IsDatasetTensor(pair.second)) {
@@ -879,7 +878,7 @@ Status MapDatasetParams::MakeInputs(
   return Status::OK();
 }
 
-Status MapDatasetParams::MakeInputPlaceholder(
+Status MapDatasetParams::GetInputPlaceholder(
     std::vector<string>* input_placeholder) const {
   input_placeholder->emplace_back(MapDatasetOp::kInputDataset);
   for (int i = 0; i < other_arguments_.size(); ++i) {
@@ -889,7 +888,7 @@ Status MapDatasetParams::MakeInputPlaceholder(
   return Status::OK();
 }
 
-Status MapDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
+Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
   *attr_vector = {
       {MapDatasetOp::kFunc, func_},
       {MapDatasetOp::kTarguments, type_arguments_},
diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h
index b8487744621..0f5257256d6 100644
--- a/tensorflow/core/kernels/data/dataset_test_base.h
+++ b/tensorflow/core/kernels/data/dataset_test_base.h
@@ -130,14 +130,14 @@ class DatasetParams {
   ~DatasetParams() {}
 
   // Returns the dataset input values as a TensorValue vector.
-  virtual Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
+  virtual Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
 
   // Returns the dataset input names as a string vector.
-  virtual Status MakeInputPlaceholder(
+  virtual Status GetInputPlaceholder(
       std::vector<string>* input_placeholder) const = 0;
 
   // Returns the dataset attributes as a vector.
-  virtual Status MakeAttributes(AttributeVector* attributes) const = 0;
+  virtual Status GetAttributes(AttributeVector* attributes) const = 0;
 
   // Checks if the tensor is a dataset variant tensor.
   static bool IsDatasetTensor(const Tensor& tensor);
@@ -177,6 +177,8 @@ class DatasetParams {
   int op_version_ = 1;
 };
 
+// `RangeDatasetParams` is a common dataset parameter type that are used in
+// testing.
 class RangeDatasetParams : public DatasetParams {
  public:
   RangeDatasetParams(int64 start, int64 stop, int64 step,
@@ -186,12 +188,12 @@ class RangeDatasetParams : public DatasetParams {
 
   RangeDatasetParams(int64 start, int64 stop, int64 step);
 
-  Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
+  Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
 
-  Status MakeInputPlaceholder(
+  Status GetInputPlaceholder(
       std::vector<string>* input_placeholder) const override;
 
-  Status MakeAttributes(AttributeVector* attr_vector) const override;
+  Status GetAttributes(AttributeVector* attr_vector) const override;
 
  private:
   Tensor start_;
@@ -199,6 +201,8 @@ class RangeDatasetParams : public DatasetParams {
   Tensor step_;
 };
 
+// `BatchDatasetParams` is a common dataset parameter type that are used in
+// testing.
 class BatchDatasetParams : public DatasetParams {
  public:
   template <typename T>
@@ -218,12 +222,12 @@ class BatchDatasetParams : public DatasetParams {
         std::make_pair(std::move(input_dataset_params_ptr), Tensor()));
   }
 
-  Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
+  Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
 
-  Status MakeInputPlaceholder(
+  Status GetInputPlaceholder(
       std::vector<string>* input_placeholder) const override;
 
-  Status MakeAttributes(AttributeVector* attr_vector) const override;
+  Status GetAttributes(AttributeVector* attr_vector) const override;
 
   int op_version() const override;
 
@@ -234,6 +238,8 @@ class BatchDatasetParams : public DatasetParams {
   int op_version_ = 2;
 };
 
+// `MapDatasetParams` is a common dataset parameter type that are used in
+// testing.
 class MapDatasetParams : public DatasetParams {
  public:
   template <typename T>
@@ -258,12 +264,12 @@ class MapDatasetParams : public DatasetParams {
         std::make_pair(std::move(input_dataset_params_ptr), Tensor()));
   }
 
-  Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
+  Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
 
-  Status MakeInputPlaceholder(
+  Status GetInputPlaceholder(
       std::vector<string>* input_placeholder) const override;
 
-  Status MakeAttributes(AttributeVector* attr_vector) const override;
+  Status GetAttributes(AttributeVector* attr_vector) const override;
 
   std::vector<FunctionDef> func_lib() const override;