From 1d3326cb6696e760e4b87a04668eb522d19673c4 Mon Sep 17 00:00:00 2001 From: Fei Hu <hufei68@gmail.com> Date: Wed, 4 Sep 2019 15:50:03 -0700 Subject: [PATCH] Address the comments --- tensorflow/core/framework/function_testlib.h | 58 +++++++++---------- .../kernels/data/batch_dataset_op_test.cc | 18 +++--- .../core/kernels/data/dataset_test_base.cc | 31 +++++----- .../core/kernels/data/dataset_test_base.h | 30 ++++++---- 4 files changed, 71 insertions(+), 66 deletions(-) diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index eb5df661e56..06b9d48c289 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -68,16 +68,16 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes, // For testing convenience, we provide a few simple functions that can // be easily executed and tested. -// x:T -> x * 2. +// x: T -> x * 2. FunctionDef XTimesTwo(); -// x:T -> cpu(x * 2) + cpu(x * 3). +// x: T -> cpu(x * 2) + cpu(x * 3). FunctionDef TwoDeviceTimesFive(); -// x:T -> cpu(x * 2), gpu(x * 3). +// x: T -> cpu(x * 2), gpu(x * 3). FunctionDef TwoDeviceMult(); -// cpu(x):T, gpu(y):T -> cpu(x * 2), gpu(y * 3). +// cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3). FunctionDef TwoDeviceInputOutput(); // Function taking a list of Tensors as input. @@ -86,25 +86,25 @@ FunctionDef FuncWithListInput(); // Function returning a list of Tensors as output. FunctionDef FuncWithListOutput(); -// x:T -> x + x. +// x: T -> x + x. FunctionDef XAddX(); -// x: T, y:T -> x + y. +// x: T, y: T -> x + y. FunctionDef XAddY(); -// x:T -> x * 2, where x is int32. +// x: T -> x * 2, where x is int32. FunctionDef XTimesTwoInt32(); -// x:T -> (x * 2) * 2. +// x: T -> (x * 2) * 2. FunctionDef XTimesFour(); -// x:T -> ((x * 2) * 2) * 2. +// x: T -> ((x * 2) * 2) * 2. FunctionDef XTimes16(); -// w:T, x:T, b:T -> MatMul(w, x) + b +// w: T, x: T, b: T -> MatMul(w, x) + b FunctionDef WXPlusB(); -// x:T -> x:T, T is a type which we automatically converts to a bool. +// x: T -> x: T, T is a type which we automatically converts to a bool. FunctionDef NonZero(); // x: T -> bool. @@ -113,56 +113,56 @@ FunctionDef IsZero(); // x: T -> int64 FunctionDef RandomUniform(); -// x:T, y:T -> y:T, x:T +// x: T, y:T -> y: T, x: T FunctionDef Swap(); -// x:T, y:T -> y:T, x:T, the body has no nodes. +// x: T, y: T -> y: T, x: T, the body has no nodes. FunctionDef EmptyBodySwap(); -// x:float, y:resource -> y:resource, 2*x:float. +// x: float, y: resource -> y: resource, 2*x: float. FunctionDef ResourceOutput(); -// x:resource -> x:resource +// x: resource -> x: resource FunctionDef ResourceIdentity(); -// x:resource -> y:float. +// x: resource -> y: float. FunctionDef ReadResourceVariable(); // Contains malformed control flow which can't be run by the executor. FunctionDef InvalidControlFlow(); -// x:T -> x <= N. +// x: T -> x <= N. FunctionDef LessThanOrEqualToN(int64 N); -// x:T, y:T -> x+1, x*y +// x: T, y: T -> x + 1, x * y FunctionDef XPlusOneXTimesY(); -// x:T, y:T -> x <= N +// x: T, y: T -> x <= N FunctionDef XYXLessThanOrEqualToN(int64 N); // x: T -> bool FunctionDef RandomUniformLess(); -// start:int64, stop:int64, step:int64 -> y:RangeDatasetOp::Dataset +// start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset FunctionDef MakeRangeDataset(); -// input_dataset:variant, batch_size:int64, drop_remainder:bool -// -> y:BatchDatasetV2::Dataset +// input_dataset: variant, batch_size: int64, drop_remainder: bool +// -> y: BatchDatasetV2::Dataset FunctionDef MakeBatchDataset(); -// input_dataset:variant, other_arguments:Targuments, f:func, -// Targuments:list(type), output_types:list(type), output_shapes:list(shape), -// use_inter_op_parallelism:bool, preserve_cardinality:bool -// -> y:MapDatasetOp::Dataset +// input_dataset: variant, other_arguments: Targuments, f: func, +// Targuments: list(type), output_types: list(type), output_shapes: list(shape), +// use_inter_op_parallelism: bool, preserve_cardinality: bool +// -> y: MapDatasetOp::Dataset FunctionDef MakeMapDataset(bool has_other_args); -// input_dataset:variant, count:int64 -> y:TakeDataset::Dataset +// input_dataset: variant, count: int64 -> y: TakeDataset::Dataset FunctionDef MakeTakeDataset(); -// x:T -> y:TensorSliceDatasetOp::Dataset +// x: T -> y: TensorSliceDatasetOp::Dataset FunctionDef MakeTensorSliceDataset(); -// x:T -> y:T, idx:out_idx +// x: T -> y: T, idx: out_idx FunctionDef Unique(); void FunctionTestSchedClosure(std::function<void()> fn); diff --git a/tensorflow/core/kernels/data/batch_dataset_op_test.cc b/tensorflow/core/kernels/data/batch_dataset_op_test.cc index ecffd7a46bc..2ba4d777c14 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op_test.cc @@ -17,7 +17,7 @@ namespace tensorflow { namespace data { namespace { -constexpr char kNodeName[] = "batch_dataset_v2"; +constexpr char kNodeName[] = "batch_dataset"; class BatchDatasetOpTest : public DatasetOpsTestBaseV2 {}; @@ -30,7 +30,7 @@ BatchDatasetParams BatchDatasetParams1() { /*parallel_copy=*/true, /*output_dtypes=*/{DT_INT64}, /*output_shapes=*/{PartialTensorShape({4})}, - /*node_name=*/"batch_dataset_v2"); + /*node_name=*/"batch_dataset"); } // Test Case 2: test BatchDatasetV2 with `drop_remainder` = true and a batch @@ -42,7 +42,7 @@ BatchDatasetParams BatchDatasetParams2() { /*parallel_copy=*/false, /*output_dtypes=*/{DT_INT64}, /*output_shapes=*/{PartialTensorShape({4})}, - /*node_name=*/"batch_dataset_v2"); + /*node_name=*/"batch_dataset"); } // Test Case 3: test BatchDatasetV2 with `drop_remainder` = false and a batch @@ -54,7 +54,7 @@ BatchDatasetParams BatchDatasetParams3() { /*parallel_copy=*/false, /*output_dtypes=*/{DT_INT64}, /*output_shapes=*/{PartialTensorShape({-1})}, - /*node_name=*/"batch_dataset_0"); + /*node_name=*/"batch_dataset"); } // Test Case 4: test BatchDatasetV2 with `drop_remainder` = true and a batch @@ -66,7 +66,7 @@ BatchDatasetParams BatchDatasetParams4() { /*parallel_copy=*/true, /*output_dtypes=*/{DT_INT64}, /*output_shapes=*/{PartialTensorShape({3})}, - /*node_name=*/"batch_dataset_v2"); + /*node_name=*/"batch_dataset"); } // Test Case 5: test BatchDatasetV2 with `drop_remainder` = true and @@ -78,7 +78,7 @@ BatchDatasetParams BatchDatasetParams5() { /*parallel_copy=*/true, /*output_dtypes=*/{DT_INT64}, /*output_shapes=*/{PartialTensorShape({12})}, - /*node_name=*/kNodeName); + /*node_name=*/"batch_dataset"); } // Test Case 6: test BatchDatasetV2 with `drop_remainder` = false and @@ -90,7 +90,7 @@ BatchDatasetParams BatchDatasetParams6() { /*parallel_copy=*/true, /*output_dtypes=*/{DT_INT64}, /*output_shapes=*/{PartialTensorShape({-1})}, - /*node_name=*/"batch_dataset_v2"); + /*node_name=*/"batch_dataset"); } // Test Case 7: test BatchDatasetV2 with `drop_remainder` = false and @@ -102,7 +102,7 @@ BatchDatasetParams BatchDatasetParams7() { /*parallel_copy=*/false, /*output_dtypes=*/{DT_INT64}, /*output_shapes=*/{PartialTensorShape({4})}, - /*node_name=*/"batch_dataset_v2"); + /*node_name=*/"batch_dataset"); } // Test Case 8: test BatchDatasetV2 with an invalid batch size @@ -113,7 +113,7 @@ BatchDatasetParams InvalidBatchSizeBatchDatasetParams() { /*parallel_copy=*/false, /*output_dtypes=*/{DT_INT64}, /*output_shapes=*/{PartialTensorShape({3})}, - /*node_name=*/"batch_dataset_v2"); + /*node_name=*/"batch_dataset"); } std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() { diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index a4c00b6f8db..fb5c092812e 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -688,7 +688,7 @@ Status DatasetOpsTestBaseV2::Initialize(DatasetParams& dataset_params) { TF_RETURN_IF_ERROR(MakeDatasetTensor(pair.first.get(), &pair.second)); } gtl::InlinedVector<TensorValue, 4> inputs; - TF_RETURN_IF_ERROR(dataset_params.MakeInputs(&inputs)); + TF_RETURN_IF_ERROR(dataset_params.GetInputs(&inputs)); TF_RETURN_IF_ERROR( CreateDatasetContext(dataset_kernel_.get(), &inputs, &dataset_ctx_)); TF_RETURN_IF_ERROR( @@ -705,9 +705,9 @@ Status DatasetOpsTestBaseV2::MakeDatasetOpKernel( name_utils::OpNameParams params; params.op_version = dataset_params.op_version(); std::vector<string> input_placeholder; - TF_RETURN_IF_ERROR(dataset_params.MakeInputPlaceholder(&input_placeholder)); + TF_RETURN_IF_ERROR(dataset_params.GetInputPlaceholder(&input_placeholder)); AttributeVector attributes; - TF_RETURN_IF_ERROR(dataset_params.MakeAttributes(&attributes)); + TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes)); NodeDef node_def = test::function::NDef( dataset_params.node_name(), name_utils::OpName(ToString(dataset_params.type()), params), @@ -724,9 +724,9 @@ Status DatasetOpsTestBaseV2::MakeDatasetTensor(DatasetParams* dataset_params, } AttributeVector attributes; - TF_RETURN_IF_ERROR(dataset_params->MakeAttributes(&attributes)); + TF_RETURN_IF_ERROR(dataset_params->GetAttributes(&attributes)); gtl::InlinedVector<TensorValue, 4> inputs; - TF_RETURN_IF_ERROR(dataset_params->MakeInputs(&inputs)); + TF_RETURN_IF_ERROR(dataset_params->GetInputs(&inputs)); std::vector<Tensor> input_tensors; for (auto& tensor_value : inputs) { input_tensors.emplace_back(*tensor_value.tensor); @@ -756,7 +756,7 @@ Status DatasetOpsTestBaseV2::MakeDatasetTensorFunc( case DatasetParamsType::Map: { std::vector<string> input_placeholder; TF_RETURN_IF_ERROR( - dataset_params.MakeInputPlaceholder(&input_placeholder)); + dataset_params.GetInputPlaceholder(&input_placeholder)); bool has_other_args = input_placeholder.size() > 1; *fdef = test::function::MakeMapDataset(has_other_args); break; @@ -808,26 +808,26 @@ RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step) stop_(CreateTensor<int64>(TensorShape({}), {stop})), step_(CreateTensor<int64>(TensorShape({}), {step})) {} -Status RangeDatasetParams::MakeInputs( +Status RangeDatasetParams::GetInputs( gtl::InlinedVector<TensorValue, 4>* inputs) { *inputs = {TensorValue(&start_), TensorValue(&stop_), TensorValue(&step_)}; return Status::OK(); } -Status RangeDatasetParams::MakeInputPlaceholder( +Status RangeDatasetParams::GetInputPlaceholder( std::vector<string>* input_placeholder) const { *input_placeholder = {RangeDatasetOp::kStart, RangeDatasetOp::kStop, RangeDatasetOp::kStep}; return Status::OK(); } -Status RangeDatasetParams::MakeAttributes(AttributeVector* attr_vector) const { +Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const { *attr_vector = {{RangeDatasetOp::kOutputTypes, output_dtypes_}, {RangeDatasetOp::kOutputShapes, output_shapes_}}; return Status::OK(); } -Status BatchDatasetParams::MakeInputs( +Status BatchDatasetParams::GetInputs( gtl::InlinedVector<TensorValue, 4>* inputs) { inputs->reserve(input_dataset_params_group_.size()); for (auto& pair : input_dataset_params_group_) { @@ -844,7 +844,7 @@ Status BatchDatasetParams::MakeInputs( return Status::OK(); } -Status BatchDatasetParams::MakeInputPlaceholder( +Status BatchDatasetParams::GetInputPlaceholder( std::vector<string>* input_placeholder) const { *input_placeholder = {BatchDatasetOp::kInputDataset, BatchDatasetOp::kBatchSize, @@ -852,7 +852,7 @@ Status BatchDatasetParams::MakeInputPlaceholder( return Status::OK(); } -Status BatchDatasetParams::MakeAttributes(AttributeVector* attr_vector) const { +Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const { *attr_vector = {{BatchDatasetOp::kParallelCopy, parallel_copy_}, {BatchDatasetOp::kOutputTypes, output_dtypes_}, {BatchDatasetOp::kOutputShapes, output_shapes_}}; @@ -861,8 +861,7 @@ Status BatchDatasetParams::MakeAttributes(AttributeVector* attr_vector) const { int BatchDatasetParams::op_version() const { return op_version_; } -Status MapDatasetParams::MakeInputs( - gtl::InlinedVector<TensorValue, 4>* inputs) { +Status MapDatasetParams::GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) { inputs->reserve(input_dataset_params_group_.size()); for (auto& pair : input_dataset_params_group_) { if (!IsDatasetTensor(pair.second)) { @@ -879,7 +878,7 @@ Status MapDatasetParams::MakeInputs( return Status::OK(); } -Status MapDatasetParams::MakeInputPlaceholder( +Status MapDatasetParams::GetInputPlaceholder( std::vector<string>* input_placeholder) const { input_placeholder->emplace_back(MapDatasetOp::kInputDataset); for (int i = 0; i < other_arguments_.size(); ++i) { @@ -889,7 +888,7 @@ Status MapDatasetParams::MakeInputPlaceholder( return Status::OK(); } -Status MapDatasetParams::MakeAttributes(AttributeVector* attr_vector) const { +Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const { *attr_vector = { {MapDatasetOp::kFunc, func_}, {MapDatasetOp::kTarguments, type_arguments_}, diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h index b8487744621..0f5257256d6 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.h +++ b/tensorflow/core/kernels/data/dataset_test_base.h @@ -130,14 +130,14 @@ class DatasetParams { ~DatasetParams() {} // Returns the dataset input values as a TensorValue vector. - virtual Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0; + virtual Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0; // Returns the dataset input names as a string vector. - virtual Status MakeInputPlaceholder( + virtual Status GetInputPlaceholder( std::vector<string>* input_placeholder) const = 0; // Returns the dataset attributes as a vector. - virtual Status MakeAttributes(AttributeVector* attributes) const = 0; + virtual Status GetAttributes(AttributeVector* attributes) const = 0; // Checks if the tensor is a dataset variant tensor. static bool IsDatasetTensor(const Tensor& tensor); @@ -177,6 +177,8 @@ class DatasetParams { int op_version_ = 1; }; +// `RangeDatasetParams` is a common dataset parameter type that are used in +// testing. class RangeDatasetParams : public DatasetParams { public: RangeDatasetParams(int64 start, int64 stop, int64 step, @@ -186,12 +188,12 @@ class RangeDatasetParams : public DatasetParams { RangeDatasetParams(int64 start, int64 stop, int64 step); - Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override; + Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override; - Status MakeInputPlaceholder( + Status GetInputPlaceholder( std::vector<string>* input_placeholder) const override; - Status MakeAttributes(AttributeVector* attr_vector) const override; + Status GetAttributes(AttributeVector* attr_vector) const override; private: Tensor start_; @@ -199,6 +201,8 @@ class RangeDatasetParams : public DatasetParams { Tensor step_; }; +// `BatchDatasetParams` is a common dataset parameter type that are used in +// testing. class BatchDatasetParams : public DatasetParams { public: template <typename T> @@ -218,12 +222,12 @@ class BatchDatasetParams : public DatasetParams { std::make_pair(std::move(input_dataset_params_ptr), Tensor())); } - Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override; + Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override; - Status MakeInputPlaceholder( + Status GetInputPlaceholder( std::vector<string>* input_placeholder) const override; - Status MakeAttributes(AttributeVector* attr_vector) const override; + Status GetAttributes(AttributeVector* attr_vector) const override; int op_version() const override; @@ -234,6 +238,8 @@ class BatchDatasetParams : public DatasetParams { int op_version_ = 2; }; +// `MapDatasetParams` is a common dataset parameter type that are used in +// testing. class MapDatasetParams : public DatasetParams { public: template <typename T> @@ -258,12 +264,12 @@ class MapDatasetParams : public DatasetParams { std::make_pair(std::move(input_dataset_params_ptr), Tensor())); } - Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override; + Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override; - Status MakeInputPlaceholder( + Status GetInputPlaceholder( std::vector<string>* input_placeholder) const override; - Status MakeAttributes(AttributeVector* attr_vector) const override; + Status GetAttributes(AttributeVector* attr_vector) const override; std::vector<FunctionDef> func_lib() const override;