Address the comments
This commit is contained in:
parent
b7833a43a8
commit
1d3326cb66
tensorflow/core
framework
kernels/data
@ -68,16 +68,16 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
|
||||
// For testing convenience, we provide a few simple functions that can
|
||||
// be easily executed and tested.
|
||||
|
||||
// x:T -> x * 2.
|
||||
// x: T -> x * 2.
|
||||
FunctionDef XTimesTwo();
|
||||
|
||||
// x:T -> cpu(x * 2) + cpu(x * 3).
|
||||
// x: T -> cpu(x * 2) + cpu(x * 3).
|
||||
FunctionDef TwoDeviceTimesFive();
|
||||
|
||||
// x:T -> cpu(x * 2), gpu(x * 3).
|
||||
// x: T -> cpu(x * 2), gpu(x * 3).
|
||||
FunctionDef TwoDeviceMult();
|
||||
|
||||
// cpu(x):T, gpu(y):T -> cpu(x * 2), gpu(y * 3).
|
||||
// cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3).
|
||||
FunctionDef TwoDeviceInputOutput();
|
||||
|
||||
// Function taking a list of Tensors as input.
|
||||
@ -86,25 +86,25 @@ FunctionDef FuncWithListInput();
|
||||
// Function returning a list of Tensors as output.
|
||||
FunctionDef FuncWithListOutput();
|
||||
|
||||
// x:T -> x + x.
|
||||
// x: T -> x + x.
|
||||
FunctionDef XAddX();
|
||||
|
||||
// x: T, y:T -> x + y.
|
||||
// x: T, y: T -> x + y.
|
||||
FunctionDef XAddY();
|
||||
|
||||
// x:T -> x * 2, where x is int32.
|
||||
// x: T -> x * 2, where x is int32.
|
||||
FunctionDef XTimesTwoInt32();
|
||||
|
||||
// x:T -> (x * 2) * 2.
|
||||
// x: T -> (x * 2) * 2.
|
||||
FunctionDef XTimesFour();
|
||||
|
||||
// x:T -> ((x * 2) * 2) * 2.
|
||||
// x: T -> ((x * 2) * 2) * 2.
|
||||
FunctionDef XTimes16();
|
||||
|
||||
// w:T, x:T, b:T -> MatMul(w, x) + b
|
||||
// w: T, x: T, b: T -> MatMul(w, x) + b
|
||||
FunctionDef WXPlusB();
|
||||
|
||||
// x:T -> x:T, T is a type which we automatically converts to a bool.
|
||||
// x: T -> x: T, T is a type which we automatically converts to a bool.
|
||||
FunctionDef NonZero();
|
||||
|
||||
// x: T -> bool.
|
||||
@ -113,56 +113,56 @@ FunctionDef IsZero();
|
||||
// x: T -> int64
|
||||
FunctionDef RandomUniform();
|
||||
|
||||
// x:T, y:T -> y:T, x:T
|
||||
// x: T, y:T -> y: T, x: T
|
||||
FunctionDef Swap();
|
||||
|
||||
// x:T, y:T -> y:T, x:T, the body has no nodes.
|
||||
// x: T, y: T -> y: T, x: T, the body has no nodes.
|
||||
FunctionDef EmptyBodySwap();
|
||||
|
||||
// x:float, y:resource -> y:resource, 2*x:float.
|
||||
// x: float, y: resource -> y: resource, 2*x: float.
|
||||
FunctionDef ResourceOutput();
|
||||
|
||||
// x:resource -> x:resource
|
||||
// x: resource -> x: resource
|
||||
FunctionDef ResourceIdentity();
|
||||
|
||||
// x:resource -> y:float.
|
||||
// x: resource -> y: float.
|
||||
FunctionDef ReadResourceVariable();
|
||||
|
||||
// Contains malformed control flow which can't be run by the executor.
|
||||
FunctionDef InvalidControlFlow();
|
||||
|
||||
// x:T -> x <= N.
|
||||
// x: T -> x <= N.
|
||||
FunctionDef LessThanOrEqualToN(int64 N);
|
||||
|
||||
// x:T, y:T -> x+1, x*y
|
||||
// x: T, y: T -> x + 1, x * y
|
||||
FunctionDef XPlusOneXTimesY();
|
||||
|
||||
// x:T, y:T -> x <= N
|
||||
// x: T, y: T -> x <= N
|
||||
FunctionDef XYXLessThanOrEqualToN(int64 N);
|
||||
|
||||
// x: T -> bool
|
||||
FunctionDef RandomUniformLess();
|
||||
|
||||
// start:int64, stop:int64, step:int64 -> y:RangeDatasetOp::Dataset
|
||||
// start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset
|
||||
FunctionDef MakeRangeDataset();
|
||||
|
||||
// input_dataset:variant, batch_size:int64, drop_remainder:bool
|
||||
// -> y:BatchDatasetV2::Dataset
|
||||
// input_dataset: variant, batch_size: int64, drop_remainder: bool
|
||||
// -> y: BatchDatasetV2::Dataset
|
||||
FunctionDef MakeBatchDataset();
|
||||
|
||||
// input_dataset:variant, other_arguments:Targuments, f:func,
|
||||
// Targuments:list(type), output_types:list(type), output_shapes:list(shape),
|
||||
// use_inter_op_parallelism:bool, preserve_cardinality:bool
|
||||
// -> y:MapDatasetOp::Dataset
|
||||
// input_dataset: variant, other_arguments: Targuments, f: func,
|
||||
// Targuments: list(type), output_types: list(type), output_shapes: list(shape),
|
||||
// use_inter_op_parallelism: bool, preserve_cardinality: bool
|
||||
// -> y: MapDatasetOp::Dataset
|
||||
FunctionDef MakeMapDataset(bool has_other_args);
|
||||
|
||||
// input_dataset:variant, count:int64 -> y:TakeDataset::Dataset
|
||||
// input_dataset: variant, count: int64 -> y: TakeDataset::Dataset
|
||||
FunctionDef MakeTakeDataset();
|
||||
|
||||
// x:T -> y:TensorSliceDatasetOp::Dataset
|
||||
// x: T -> y: TensorSliceDatasetOp::Dataset
|
||||
FunctionDef MakeTensorSliceDataset();
|
||||
|
||||
// x:T -> y:T, idx:out_idx
|
||||
// x: T -> y: T, idx: out_idx
|
||||
FunctionDef Unique();
|
||||
|
||||
void FunctionTestSchedClosure(std::function<void()> fn);
|
||||
|
@ -17,7 +17,7 @@ namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kNodeName[] = "batch_dataset_v2";
|
||||
constexpr char kNodeName[] = "batch_dataset";
|
||||
|
||||
class BatchDatasetOpTest : public DatasetOpsTestBaseV2 {};
|
||||
|
||||
@ -30,7 +30,7 @@ BatchDatasetParams BatchDatasetParams1() {
|
||||
/*parallel_copy=*/true,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({4})},
|
||||
/*node_name=*/"batch_dataset_v2");
|
||||
/*node_name=*/"batch_dataset");
|
||||
}
|
||||
|
||||
// Test Case 2: test BatchDatasetV2 with `drop_remainder` = true and a batch
|
||||
@ -42,7 +42,7 @@ BatchDatasetParams BatchDatasetParams2() {
|
||||
/*parallel_copy=*/false,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({4})},
|
||||
/*node_name=*/"batch_dataset_v2");
|
||||
/*node_name=*/"batch_dataset");
|
||||
}
|
||||
|
||||
// Test Case 3: test BatchDatasetV2 with `drop_remainder` = false and a batch
|
||||
@ -54,7 +54,7 @@ BatchDatasetParams BatchDatasetParams3() {
|
||||
/*parallel_copy=*/false,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({-1})},
|
||||
/*node_name=*/"batch_dataset_0");
|
||||
/*node_name=*/"batch_dataset");
|
||||
}
|
||||
|
||||
// Test Case 4: test BatchDatasetV2 with `drop_remainder` = true and a batch
|
||||
@ -66,7 +66,7 @@ BatchDatasetParams BatchDatasetParams4() {
|
||||
/*parallel_copy=*/true,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({3})},
|
||||
/*node_name=*/"batch_dataset_v2");
|
||||
/*node_name=*/"batch_dataset");
|
||||
}
|
||||
|
||||
// Test Case 5: test BatchDatasetV2 with `drop_remainder` = true and
|
||||
@ -78,7 +78,7 @@ BatchDatasetParams BatchDatasetParams5() {
|
||||
/*parallel_copy=*/true,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({12})},
|
||||
/*node_name=*/kNodeName);
|
||||
/*node_name=*/"batch_dataset");
|
||||
}
|
||||
|
||||
// Test Case 6: test BatchDatasetV2 with `drop_remainder` = false and
|
||||
@ -90,7 +90,7 @@ BatchDatasetParams BatchDatasetParams6() {
|
||||
/*parallel_copy=*/true,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({-1})},
|
||||
/*node_name=*/"batch_dataset_v2");
|
||||
/*node_name=*/"batch_dataset");
|
||||
}
|
||||
|
||||
// Test Case 7: test BatchDatasetV2 with `drop_remainder` = false and
|
||||
@ -102,7 +102,7 @@ BatchDatasetParams BatchDatasetParams7() {
|
||||
/*parallel_copy=*/false,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({4})},
|
||||
/*node_name=*/"batch_dataset_v2");
|
||||
/*node_name=*/"batch_dataset");
|
||||
}
|
||||
|
||||
// Test Case 8: test BatchDatasetV2 with an invalid batch size
|
||||
@ -113,7 +113,7 @@ BatchDatasetParams InvalidBatchSizeBatchDatasetParams() {
|
||||
/*parallel_copy=*/false,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({3})},
|
||||
/*node_name=*/"batch_dataset_v2");
|
||||
/*node_name=*/"batch_dataset");
|
||||
}
|
||||
|
||||
std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
|
||||
|
@ -688,7 +688,7 @@ Status DatasetOpsTestBaseV2::Initialize(DatasetParams& dataset_params) {
|
||||
TF_RETURN_IF_ERROR(MakeDatasetTensor(pair.first.get(), &pair.second));
|
||||
}
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
TF_RETURN_IF_ERROR(dataset_params.MakeInputs(&inputs));
|
||||
TF_RETURN_IF_ERROR(dataset_params.GetInputs(&inputs));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateDatasetContext(dataset_kernel_.get(), &inputs, &dataset_ctx_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -705,9 +705,9 @@ Status DatasetOpsTestBaseV2::MakeDatasetOpKernel(
|
||||
name_utils::OpNameParams params;
|
||||
params.op_version = dataset_params.op_version();
|
||||
std::vector<string> input_placeholder;
|
||||
TF_RETURN_IF_ERROR(dataset_params.MakeInputPlaceholder(&input_placeholder));
|
||||
TF_RETURN_IF_ERROR(dataset_params.GetInputPlaceholder(&input_placeholder));
|
||||
AttributeVector attributes;
|
||||
TF_RETURN_IF_ERROR(dataset_params.MakeAttributes(&attributes));
|
||||
TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
|
||||
NodeDef node_def = test::function::NDef(
|
||||
dataset_params.node_name(),
|
||||
name_utils::OpName(ToString(dataset_params.type()), params),
|
||||
@ -724,9 +724,9 @@ Status DatasetOpsTestBaseV2::MakeDatasetTensor(DatasetParams* dataset_params,
|
||||
}
|
||||
|
||||
AttributeVector attributes;
|
||||
TF_RETURN_IF_ERROR(dataset_params->MakeAttributes(&attributes));
|
||||
TF_RETURN_IF_ERROR(dataset_params->GetAttributes(&attributes));
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
TF_RETURN_IF_ERROR(dataset_params->MakeInputs(&inputs));
|
||||
TF_RETURN_IF_ERROR(dataset_params->GetInputs(&inputs));
|
||||
std::vector<Tensor> input_tensors;
|
||||
for (auto& tensor_value : inputs) {
|
||||
input_tensors.emplace_back(*tensor_value.tensor);
|
||||
@ -756,7 +756,7 @@ Status DatasetOpsTestBaseV2::MakeDatasetTensorFunc(
|
||||
case DatasetParamsType::Map: {
|
||||
std::vector<string> input_placeholder;
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset_params.MakeInputPlaceholder(&input_placeholder));
|
||||
dataset_params.GetInputPlaceholder(&input_placeholder));
|
||||
bool has_other_args = input_placeholder.size() > 1;
|
||||
*fdef = test::function::MakeMapDataset(has_other_args);
|
||||
break;
|
||||
@ -808,26 +808,26 @@ RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step)
|
||||
stop_(CreateTensor<int64>(TensorShape({}), {stop})),
|
||||
step_(CreateTensor<int64>(TensorShape({}), {step})) {}
|
||||
|
||||
Status RangeDatasetParams::MakeInputs(
|
||||
Status RangeDatasetParams::GetInputs(
|
||||
gtl::InlinedVector<TensorValue, 4>* inputs) {
|
||||
*inputs = {TensorValue(&start_), TensorValue(&stop_), TensorValue(&step_)};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RangeDatasetParams::MakeInputPlaceholder(
|
||||
Status RangeDatasetParams::GetInputPlaceholder(
|
||||
std::vector<string>* input_placeholder) const {
|
||||
*input_placeholder = {RangeDatasetOp::kStart, RangeDatasetOp::kStop,
|
||||
RangeDatasetOp::kStep};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RangeDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
|
||||
Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
|
||||
*attr_vector = {{RangeDatasetOp::kOutputTypes, output_dtypes_},
|
||||
{RangeDatasetOp::kOutputShapes, output_shapes_}};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchDatasetParams::MakeInputs(
|
||||
Status BatchDatasetParams::GetInputs(
|
||||
gtl::InlinedVector<TensorValue, 4>* inputs) {
|
||||
inputs->reserve(input_dataset_params_group_.size());
|
||||
for (auto& pair : input_dataset_params_group_) {
|
||||
@ -844,7 +844,7 @@ Status BatchDatasetParams::MakeInputs(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchDatasetParams::MakeInputPlaceholder(
|
||||
Status BatchDatasetParams::GetInputPlaceholder(
|
||||
std::vector<string>* input_placeholder) const {
|
||||
*input_placeholder = {BatchDatasetOp::kInputDataset,
|
||||
BatchDatasetOp::kBatchSize,
|
||||
@ -852,7 +852,7 @@ Status BatchDatasetParams::MakeInputPlaceholder(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BatchDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
|
||||
Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
|
||||
*attr_vector = {{BatchDatasetOp::kParallelCopy, parallel_copy_},
|
||||
{BatchDatasetOp::kOutputTypes, output_dtypes_},
|
||||
{BatchDatasetOp::kOutputShapes, output_shapes_}};
|
||||
@ -861,8 +861,7 @@ Status BatchDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
|
||||
|
||||
int BatchDatasetParams::op_version() const { return op_version_; }
|
||||
|
||||
Status MapDatasetParams::MakeInputs(
|
||||
gtl::InlinedVector<TensorValue, 4>* inputs) {
|
||||
Status MapDatasetParams::GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) {
|
||||
inputs->reserve(input_dataset_params_group_.size());
|
||||
for (auto& pair : input_dataset_params_group_) {
|
||||
if (!IsDatasetTensor(pair.second)) {
|
||||
@ -879,7 +878,7 @@ Status MapDatasetParams::MakeInputs(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MapDatasetParams::MakeInputPlaceholder(
|
||||
Status MapDatasetParams::GetInputPlaceholder(
|
||||
std::vector<string>* input_placeholder) const {
|
||||
input_placeholder->emplace_back(MapDatasetOp::kInputDataset);
|
||||
for (int i = 0; i < other_arguments_.size(); ++i) {
|
||||
@ -889,7 +888,7 @@ Status MapDatasetParams::MakeInputPlaceholder(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MapDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
|
||||
Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
|
||||
*attr_vector = {
|
||||
{MapDatasetOp::kFunc, func_},
|
||||
{MapDatasetOp::kTarguments, type_arguments_},
|
||||
|
@ -130,14 +130,14 @@ class DatasetParams {
|
||||
~DatasetParams() {}
|
||||
|
||||
// Returns the dataset input values as a TensorValue vector.
|
||||
virtual Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
|
||||
virtual Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
|
||||
|
||||
// Returns the dataset input names as a string vector.
|
||||
virtual Status MakeInputPlaceholder(
|
||||
virtual Status GetInputPlaceholder(
|
||||
std::vector<string>* input_placeholder) const = 0;
|
||||
|
||||
// Returns the dataset attributes as a vector.
|
||||
virtual Status MakeAttributes(AttributeVector* attributes) const = 0;
|
||||
virtual Status GetAttributes(AttributeVector* attributes) const = 0;
|
||||
|
||||
// Checks if the tensor is a dataset variant tensor.
|
||||
static bool IsDatasetTensor(const Tensor& tensor);
|
||||
@ -177,6 +177,8 @@ class DatasetParams {
|
||||
int op_version_ = 1;
|
||||
};
|
||||
|
||||
// `RangeDatasetParams` is a common dataset parameter type that are used in
|
||||
// testing.
|
||||
class RangeDatasetParams : public DatasetParams {
|
||||
public:
|
||||
RangeDatasetParams(int64 start, int64 stop, int64 step,
|
||||
@ -186,12 +188,12 @@ class RangeDatasetParams : public DatasetParams {
|
||||
|
||||
RangeDatasetParams(int64 start, int64 stop, int64 step);
|
||||
|
||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||
Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||
|
||||
Status MakeInputPlaceholder(
|
||||
Status GetInputPlaceholder(
|
||||
std::vector<string>* input_placeholder) const override;
|
||||
|
||||
Status MakeAttributes(AttributeVector* attr_vector) const override;
|
||||
Status GetAttributes(AttributeVector* attr_vector) const override;
|
||||
|
||||
private:
|
||||
Tensor start_;
|
||||
@ -199,6 +201,8 @@ class RangeDatasetParams : public DatasetParams {
|
||||
Tensor step_;
|
||||
};
|
||||
|
||||
// `BatchDatasetParams` is a common dataset parameter type that are used in
|
||||
// testing.
|
||||
class BatchDatasetParams : public DatasetParams {
|
||||
public:
|
||||
template <typename T>
|
||||
@ -218,12 +222,12 @@ class BatchDatasetParams : public DatasetParams {
|
||||
std::make_pair(std::move(input_dataset_params_ptr), Tensor()));
|
||||
}
|
||||
|
||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||
Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||
|
||||
Status MakeInputPlaceholder(
|
||||
Status GetInputPlaceholder(
|
||||
std::vector<string>* input_placeholder) const override;
|
||||
|
||||
Status MakeAttributes(AttributeVector* attr_vector) const override;
|
||||
Status GetAttributes(AttributeVector* attr_vector) const override;
|
||||
|
||||
int op_version() const override;
|
||||
|
||||
@ -234,6 +238,8 @@ class BatchDatasetParams : public DatasetParams {
|
||||
int op_version_ = 2;
|
||||
};
|
||||
|
||||
// `MapDatasetParams` is a common dataset parameter type that are used in
|
||||
// testing.
|
||||
class MapDatasetParams : public DatasetParams {
|
||||
public:
|
||||
template <typename T>
|
||||
@ -258,12 +264,12 @@ class MapDatasetParams : public DatasetParams {
|
||||
std::make_pair(std::move(input_dataset_params_ptr), Tensor()));
|
||||
}
|
||||
|
||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||
Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||
|
||||
Status MakeInputPlaceholder(
|
||||
Status GetInputPlaceholder(
|
||||
std::vector<string>* input_placeholder) const override;
|
||||
|
||||
Status MakeAttributes(AttributeVector* attr_vector) const override;
|
||||
Status GetAttributes(AttributeVector* attr_vector) const override;
|
||||
|
||||
std::vector<FunctionDef> func_lib() const override;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user