Address the comments
This commit is contained in:
parent
b7833a43a8
commit
1d3326cb66
@ -68,16 +68,16 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
|
|||||||
// For testing convenience, we provide a few simple functions that can
|
// For testing convenience, we provide a few simple functions that can
|
||||||
// be easily executed and tested.
|
// be easily executed and tested.
|
||||||
|
|
||||||
// x:T -> x * 2.
|
// x: T -> x * 2.
|
||||||
FunctionDef XTimesTwo();
|
FunctionDef XTimesTwo();
|
||||||
|
|
||||||
// x:T -> cpu(x * 2) + cpu(x * 3).
|
// x: T -> cpu(x * 2) + cpu(x * 3).
|
||||||
FunctionDef TwoDeviceTimesFive();
|
FunctionDef TwoDeviceTimesFive();
|
||||||
|
|
||||||
// x:T -> cpu(x * 2), gpu(x * 3).
|
// x: T -> cpu(x * 2), gpu(x * 3).
|
||||||
FunctionDef TwoDeviceMult();
|
FunctionDef TwoDeviceMult();
|
||||||
|
|
||||||
// cpu(x):T, gpu(y):T -> cpu(x * 2), gpu(y * 3).
|
// cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3).
|
||||||
FunctionDef TwoDeviceInputOutput();
|
FunctionDef TwoDeviceInputOutput();
|
||||||
|
|
||||||
// Function taking a list of Tensors as input.
|
// Function taking a list of Tensors as input.
|
||||||
@ -86,25 +86,25 @@ FunctionDef FuncWithListInput();
|
|||||||
// Function returning a list of Tensors as output.
|
// Function returning a list of Tensors as output.
|
||||||
FunctionDef FuncWithListOutput();
|
FunctionDef FuncWithListOutput();
|
||||||
|
|
||||||
// x:T -> x + x.
|
// x: T -> x + x.
|
||||||
FunctionDef XAddX();
|
FunctionDef XAddX();
|
||||||
|
|
||||||
// x: T, y:T -> x + y.
|
// x: T, y: T -> x + y.
|
||||||
FunctionDef XAddY();
|
FunctionDef XAddY();
|
||||||
|
|
||||||
// x:T -> x * 2, where x is int32.
|
// x: T -> x * 2, where x is int32.
|
||||||
FunctionDef XTimesTwoInt32();
|
FunctionDef XTimesTwoInt32();
|
||||||
|
|
||||||
// x:T -> (x * 2) * 2.
|
// x: T -> (x * 2) * 2.
|
||||||
FunctionDef XTimesFour();
|
FunctionDef XTimesFour();
|
||||||
|
|
||||||
// x:T -> ((x * 2) * 2) * 2.
|
// x: T -> ((x * 2) * 2) * 2.
|
||||||
FunctionDef XTimes16();
|
FunctionDef XTimes16();
|
||||||
|
|
||||||
// w:T, x:T, b:T -> MatMul(w, x) + b
|
// w: T, x: T, b: T -> MatMul(w, x) + b
|
||||||
FunctionDef WXPlusB();
|
FunctionDef WXPlusB();
|
||||||
|
|
||||||
// x:T -> x:T, T is a type which we automatically converts to a bool.
|
// x: T -> x: T, T is a type which we automatically converts to a bool.
|
||||||
FunctionDef NonZero();
|
FunctionDef NonZero();
|
||||||
|
|
||||||
// x: T -> bool.
|
// x: T -> bool.
|
||||||
@ -113,56 +113,56 @@ FunctionDef IsZero();
|
|||||||
// x: T -> int64
|
// x: T -> int64
|
||||||
FunctionDef RandomUniform();
|
FunctionDef RandomUniform();
|
||||||
|
|
||||||
// x:T, y:T -> y:T, x:T
|
// x: T, y:T -> y: T, x: T
|
||||||
FunctionDef Swap();
|
FunctionDef Swap();
|
||||||
|
|
||||||
// x:T, y:T -> y:T, x:T, the body has no nodes.
|
// x: T, y: T -> y: T, x: T, the body has no nodes.
|
||||||
FunctionDef EmptyBodySwap();
|
FunctionDef EmptyBodySwap();
|
||||||
|
|
||||||
// x:float, y:resource -> y:resource, 2*x:float.
|
// x: float, y: resource -> y: resource, 2*x: float.
|
||||||
FunctionDef ResourceOutput();
|
FunctionDef ResourceOutput();
|
||||||
|
|
||||||
// x:resource -> x:resource
|
// x: resource -> x: resource
|
||||||
FunctionDef ResourceIdentity();
|
FunctionDef ResourceIdentity();
|
||||||
|
|
||||||
// x:resource -> y:float.
|
// x: resource -> y: float.
|
||||||
FunctionDef ReadResourceVariable();
|
FunctionDef ReadResourceVariable();
|
||||||
|
|
||||||
// Contains malformed control flow which can't be run by the executor.
|
// Contains malformed control flow which can't be run by the executor.
|
||||||
FunctionDef InvalidControlFlow();
|
FunctionDef InvalidControlFlow();
|
||||||
|
|
||||||
// x:T -> x <= N.
|
// x: T -> x <= N.
|
||||||
FunctionDef LessThanOrEqualToN(int64 N);
|
FunctionDef LessThanOrEqualToN(int64 N);
|
||||||
|
|
||||||
// x:T, y:T -> x+1, x*y
|
// x: T, y: T -> x + 1, x * y
|
||||||
FunctionDef XPlusOneXTimesY();
|
FunctionDef XPlusOneXTimesY();
|
||||||
|
|
||||||
// x:T, y:T -> x <= N
|
// x: T, y: T -> x <= N
|
||||||
FunctionDef XYXLessThanOrEqualToN(int64 N);
|
FunctionDef XYXLessThanOrEqualToN(int64 N);
|
||||||
|
|
||||||
// x: T -> bool
|
// x: T -> bool
|
||||||
FunctionDef RandomUniformLess();
|
FunctionDef RandomUniformLess();
|
||||||
|
|
||||||
// start:int64, stop:int64, step:int64 -> y:RangeDatasetOp::Dataset
|
// start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset
|
||||||
FunctionDef MakeRangeDataset();
|
FunctionDef MakeRangeDataset();
|
||||||
|
|
||||||
// input_dataset:variant, batch_size:int64, drop_remainder:bool
|
// input_dataset: variant, batch_size: int64, drop_remainder: bool
|
||||||
// -> y:BatchDatasetV2::Dataset
|
// -> y: BatchDatasetV2::Dataset
|
||||||
FunctionDef MakeBatchDataset();
|
FunctionDef MakeBatchDataset();
|
||||||
|
|
||||||
// input_dataset:variant, other_arguments:Targuments, f:func,
|
// input_dataset: variant, other_arguments: Targuments, f: func,
|
||||||
// Targuments:list(type), output_types:list(type), output_shapes:list(shape),
|
// Targuments: list(type), output_types: list(type), output_shapes: list(shape),
|
||||||
// use_inter_op_parallelism:bool, preserve_cardinality:bool
|
// use_inter_op_parallelism: bool, preserve_cardinality: bool
|
||||||
// -> y:MapDatasetOp::Dataset
|
// -> y: MapDatasetOp::Dataset
|
||||||
FunctionDef MakeMapDataset(bool has_other_args);
|
FunctionDef MakeMapDataset(bool has_other_args);
|
||||||
|
|
||||||
// input_dataset:variant, count:int64 -> y:TakeDataset::Dataset
|
// input_dataset: variant, count: int64 -> y: TakeDataset::Dataset
|
||||||
FunctionDef MakeTakeDataset();
|
FunctionDef MakeTakeDataset();
|
||||||
|
|
||||||
// x:T -> y:TensorSliceDatasetOp::Dataset
|
// x: T -> y: TensorSliceDatasetOp::Dataset
|
||||||
FunctionDef MakeTensorSliceDataset();
|
FunctionDef MakeTensorSliceDataset();
|
||||||
|
|
||||||
// x:T -> y:T, idx:out_idx
|
// x: T -> y: T, idx: out_idx
|
||||||
FunctionDef Unique();
|
FunctionDef Unique();
|
||||||
|
|
||||||
void FunctionTestSchedClosure(std::function<void()> fn);
|
void FunctionTestSchedClosure(std::function<void()> fn);
|
||||||
|
@ -17,7 +17,7 @@ namespace tensorflow {
|
|||||||
namespace data {
|
namespace data {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kNodeName[] = "batch_dataset_v2";
|
constexpr char kNodeName[] = "batch_dataset";
|
||||||
|
|
||||||
class BatchDatasetOpTest : public DatasetOpsTestBaseV2 {};
|
class BatchDatasetOpTest : public DatasetOpsTestBaseV2 {};
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ BatchDatasetParams BatchDatasetParams1() {
|
|||||||
/*parallel_copy=*/true,
|
/*parallel_copy=*/true,
|
||||||
/*output_dtypes=*/{DT_INT64},
|
/*output_dtypes=*/{DT_INT64},
|
||||||
/*output_shapes=*/{PartialTensorShape({4})},
|
/*output_shapes=*/{PartialTensorShape({4})},
|
||||||
/*node_name=*/"batch_dataset_v2");
|
/*node_name=*/"batch_dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test Case 2: test BatchDatasetV2 with `drop_remainder` = true and a batch
|
// Test Case 2: test BatchDatasetV2 with `drop_remainder` = true and a batch
|
||||||
@ -42,7 +42,7 @@ BatchDatasetParams BatchDatasetParams2() {
|
|||||||
/*parallel_copy=*/false,
|
/*parallel_copy=*/false,
|
||||||
/*output_dtypes=*/{DT_INT64},
|
/*output_dtypes=*/{DT_INT64},
|
||||||
/*output_shapes=*/{PartialTensorShape({4})},
|
/*output_shapes=*/{PartialTensorShape({4})},
|
||||||
/*node_name=*/"batch_dataset_v2");
|
/*node_name=*/"batch_dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test Case 3: test BatchDatasetV2 with `drop_remainder` = false and a batch
|
// Test Case 3: test BatchDatasetV2 with `drop_remainder` = false and a batch
|
||||||
@ -54,7 +54,7 @@ BatchDatasetParams BatchDatasetParams3() {
|
|||||||
/*parallel_copy=*/false,
|
/*parallel_copy=*/false,
|
||||||
/*output_dtypes=*/{DT_INT64},
|
/*output_dtypes=*/{DT_INT64},
|
||||||
/*output_shapes=*/{PartialTensorShape({-1})},
|
/*output_shapes=*/{PartialTensorShape({-1})},
|
||||||
/*node_name=*/"batch_dataset_0");
|
/*node_name=*/"batch_dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test Case 4: test BatchDatasetV2 with `drop_remainder` = true and a batch
|
// Test Case 4: test BatchDatasetV2 with `drop_remainder` = true and a batch
|
||||||
@ -66,7 +66,7 @@ BatchDatasetParams BatchDatasetParams4() {
|
|||||||
/*parallel_copy=*/true,
|
/*parallel_copy=*/true,
|
||||||
/*output_dtypes=*/{DT_INT64},
|
/*output_dtypes=*/{DT_INT64},
|
||||||
/*output_shapes=*/{PartialTensorShape({3})},
|
/*output_shapes=*/{PartialTensorShape({3})},
|
||||||
/*node_name=*/"batch_dataset_v2");
|
/*node_name=*/"batch_dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test Case 5: test BatchDatasetV2 with `drop_remainder` = true and
|
// Test Case 5: test BatchDatasetV2 with `drop_remainder` = true and
|
||||||
@ -78,7 +78,7 @@ BatchDatasetParams BatchDatasetParams5() {
|
|||||||
/*parallel_copy=*/true,
|
/*parallel_copy=*/true,
|
||||||
/*output_dtypes=*/{DT_INT64},
|
/*output_dtypes=*/{DT_INT64},
|
||||||
/*output_shapes=*/{PartialTensorShape({12})},
|
/*output_shapes=*/{PartialTensorShape({12})},
|
||||||
/*node_name=*/kNodeName);
|
/*node_name=*/"batch_dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test Case 6: test BatchDatasetV2 with `drop_remainder` = false and
|
// Test Case 6: test BatchDatasetV2 with `drop_remainder` = false and
|
||||||
@ -90,7 +90,7 @@ BatchDatasetParams BatchDatasetParams6() {
|
|||||||
/*parallel_copy=*/true,
|
/*parallel_copy=*/true,
|
||||||
/*output_dtypes=*/{DT_INT64},
|
/*output_dtypes=*/{DT_INT64},
|
||||||
/*output_shapes=*/{PartialTensorShape({-1})},
|
/*output_shapes=*/{PartialTensorShape({-1})},
|
||||||
/*node_name=*/"batch_dataset_v2");
|
/*node_name=*/"batch_dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test Case 7: test BatchDatasetV2 with `drop_remainder` = false and
|
// Test Case 7: test BatchDatasetV2 with `drop_remainder` = false and
|
||||||
@ -102,7 +102,7 @@ BatchDatasetParams BatchDatasetParams7() {
|
|||||||
/*parallel_copy=*/false,
|
/*parallel_copy=*/false,
|
||||||
/*output_dtypes=*/{DT_INT64},
|
/*output_dtypes=*/{DT_INT64},
|
||||||
/*output_shapes=*/{PartialTensorShape({4})},
|
/*output_shapes=*/{PartialTensorShape({4})},
|
||||||
/*node_name=*/"batch_dataset_v2");
|
/*node_name=*/"batch_dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test Case 8: test BatchDatasetV2 with an invalid batch size
|
// Test Case 8: test BatchDatasetV2 with an invalid batch size
|
||||||
@ -113,7 +113,7 @@ BatchDatasetParams InvalidBatchSizeBatchDatasetParams() {
|
|||||||
/*parallel_copy=*/false,
|
/*parallel_copy=*/false,
|
||||||
/*output_dtypes=*/{DT_INT64},
|
/*output_dtypes=*/{DT_INT64},
|
||||||
/*output_shapes=*/{PartialTensorShape({3})},
|
/*output_shapes=*/{PartialTensorShape({3})},
|
||||||
/*node_name=*/"batch_dataset_v2");
|
/*node_name=*/"batch_dataset");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
|
std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
|
||||||
|
@ -688,7 +688,7 @@ Status DatasetOpsTestBaseV2::Initialize(DatasetParams& dataset_params) {
|
|||||||
TF_RETURN_IF_ERROR(MakeDatasetTensor(pair.first.get(), &pair.second));
|
TF_RETURN_IF_ERROR(MakeDatasetTensor(pair.first.get(), &pair.second));
|
||||||
}
|
}
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
TF_RETURN_IF_ERROR(dataset_params.MakeInputs(&inputs));
|
TF_RETURN_IF_ERROR(dataset_params.GetInputs(&inputs));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
CreateDatasetContext(dataset_kernel_.get(), &inputs, &dataset_ctx_));
|
CreateDatasetContext(dataset_kernel_.get(), &inputs, &dataset_ctx_));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -705,9 +705,9 @@ Status DatasetOpsTestBaseV2::MakeDatasetOpKernel(
|
|||||||
name_utils::OpNameParams params;
|
name_utils::OpNameParams params;
|
||||||
params.op_version = dataset_params.op_version();
|
params.op_version = dataset_params.op_version();
|
||||||
std::vector<string> input_placeholder;
|
std::vector<string> input_placeholder;
|
||||||
TF_RETURN_IF_ERROR(dataset_params.MakeInputPlaceholder(&input_placeholder));
|
TF_RETURN_IF_ERROR(dataset_params.GetInputPlaceholder(&input_placeholder));
|
||||||
AttributeVector attributes;
|
AttributeVector attributes;
|
||||||
TF_RETURN_IF_ERROR(dataset_params.MakeAttributes(&attributes));
|
TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes));
|
||||||
NodeDef node_def = test::function::NDef(
|
NodeDef node_def = test::function::NDef(
|
||||||
dataset_params.node_name(),
|
dataset_params.node_name(),
|
||||||
name_utils::OpName(ToString(dataset_params.type()), params),
|
name_utils::OpName(ToString(dataset_params.type()), params),
|
||||||
@ -724,9 +724,9 @@ Status DatasetOpsTestBaseV2::MakeDatasetTensor(DatasetParams* dataset_params,
|
|||||||
}
|
}
|
||||||
|
|
||||||
AttributeVector attributes;
|
AttributeVector attributes;
|
||||||
TF_RETURN_IF_ERROR(dataset_params->MakeAttributes(&attributes));
|
TF_RETURN_IF_ERROR(dataset_params->GetAttributes(&attributes));
|
||||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
TF_RETURN_IF_ERROR(dataset_params->MakeInputs(&inputs));
|
TF_RETURN_IF_ERROR(dataset_params->GetInputs(&inputs));
|
||||||
std::vector<Tensor> input_tensors;
|
std::vector<Tensor> input_tensors;
|
||||||
for (auto& tensor_value : inputs) {
|
for (auto& tensor_value : inputs) {
|
||||||
input_tensors.emplace_back(*tensor_value.tensor);
|
input_tensors.emplace_back(*tensor_value.tensor);
|
||||||
@ -756,7 +756,7 @@ Status DatasetOpsTestBaseV2::MakeDatasetTensorFunc(
|
|||||||
case DatasetParamsType::Map: {
|
case DatasetParamsType::Map: {
|
||||||
std::vector<string> input_placeholder;
|
std::vector<string> input_placeholder;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
dataset_params.MakeInputPlaceholder(&input_placeholder));
|
dataset_params.GetInputPlaceholder(&input_placeholder));
|
||||||
bool has_other_args = input_placeholder.size() > 1;
|
bool has_other_args = input_placeholder.size() > 1;
|
||||||
*fdef = test::function::MakeMapDataset(has_other_args);
|
*fdef = test::function::MakeMapDataset(has_other_args);
|
||||||
break;
|
break;
|
||||||
@ -808,26 +808,26 @@ RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step)
|
|||||||
stop_(CreateTensor<int64>(TensorShape({}), {stop})),
|
stop_(CreateTensor<int64>(TensorShape({}), {stop})),
|
||||||
step_(CreateTensor<int64>(TensorShape({}), {step})) {}
|
step_(CreateTensor<int64>(TensorShape({}), {step})) {}
|
||||||
|
|
||||||
Status RangeDatasetParams::MakeInputs(
|
Status RangeDatasetParams::GetInputs(
|
||||||
gtl::InlinedVector<TensorValue, 4>* inputs) {
|
gtl::InlinedVector<TensorValue, 4>* inputs) {
|
||||||
*inputs = {TensorValue(&start_), TensorValue(&stop_), TensorValue(&step_)};
|
*inputs = {TensorValue(&start_), TensorValue(&stop_), TensorValue(&step_)};
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RangeDatasetParams::MakeInputPlaceholder(
|
Status RangeDatasetParams::GetInputPlaceholder(
|
||||||
std::vector<string>* input_placeholder) const {
|
std::vector<string>* input_placeholder) const {
|
||||||
*input_placeholder = {RangeDatasetOp::kStart, RangeDatasetOp::kStop,
|
*input_placeholder = {RangeDatasetOp::kStart, RangeDatasetOp::kStop,
|
||||||
RangeDatasetOp::kStep};
|
RangeDatasetOp::kStep};
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RangeDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
|
Status RangeDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
|
||||||
*attr_vector = {{RangeDatasetOp::kOutputTypes, output_dtypes_},
|
*attr_vector = {{RangeDatasetOp::kOutputTypes, output_dtypes_},
|
||||||
{RangeDatasetOp::kOutputShapes, output_shapes_}};
|
{RangeDatasetOp::kOutputShapes, output_shapes_}};
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status BatchDatasetParams::MakeInputs(
|
Status BatchDatasetParams::GetInputs(
|
||||||
gtl::InlinedVector<TensorValue, 4>* inputs) {
|
gtl::InlinedVector<TensorValue, 4>* inputs) {
|
||||||
inputs->reserve(input_dataset_params_group_.size());
|
inputs->reserve(input_dataset_params_group_.size());
|
||||||
for (auto& pair : input_dataset_params_group_) {
|
for (auto& pair : input_dataset_params_group_) {
|
||||||
@ -844,7 +844,7 @@ Status BatchDatasetParams::MakeInputs(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status BatchDatasetParams::MakeInputPlaceholder(
|
Status BatchDatasetParams::GetInputPlaceholder(
|
||||||
std::vector<string>* input_placeholder) const {
|
std::vector<string>* input_placeholder) const {
|
||||||
*input_placeholder = {BatchDatasetOp::kInputDataset,
|
*input_placeholder = {BatchDatasetOp::kInputDataset,
|
||||||
BatchDatasetOp::kBatchSize,
|
BatchDatasetOp::kBatchSize,
|
||||||
@ -852,7 +852,7 @@ Status BatchDatasetParams::MakeInputPlaceholder(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status BatchDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
|
Status BatchDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
|
||||||
*attr_vector = {{BatchDatasetOp::kParallelCopy, parallel_copy_},
|
*attr_vector = {{BatchDatasetOp::kParallelCopy, parallel_copy_},
|
||||||
{BatchDatasetOp::kOutputTypes, output_dtypes_},
|
{BatchDatasetOp::kOutputTypes, output_dtypes_},
|
||||||
{BatchDatasetOp::kOutputShapes, output_shapes_}};
|
{BatchDatasetOp::kOutputShapes, output_shapes_}};
|
||||||
@ -861,8 +861,7 @@ Status BatchDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
|
|||||||
|
|
||||||
int BatchDatasetParams::op_version() const { return op_version_; }
|
int BatchDatasetParams::op_version() const { return op_version_; }
|
||||||
|
|
||||||
Status MapDatasetParams::MakeInputs(
|
Status MapDatasetParams::GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) {
|
||||||
gtl::InlinedVector<TensorValue, 4>* inputs) {
|
|
||||||
inputs->reserve(input_dataset_params_group_.size());
|
inputs->reserve(input_dataset_params_group_.size());
|
||||||
for (auto& pair : input_dataset_params_group_) {
|
for (auto& pair : input_dataset_params_group_) {
|
||||||
if (!IsDatasetTensor(pair.second)) {
|
if (!IsDatasetTensor(pair.second)) {
|
||||||
@ -879,7 +878,7 @@ Status MapDatasetParams::MakeInputs(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MapDatasetParams::MakeInputPlaceholder(
|
Status MapDatasetParams::GetInputPlaceholder(
|
||||||
std::vector<string>* input_placeholder) const {
|
std::vector<string>* input_placeholder) const {
|
||||||
input_placeholder->emplace_back(MapDatasetOp::kInputDataset);
|
input_placeholder->emplace_back(MapDatasetOp::kInputDataset);
|
||||||
for (int i = 0; i < other_arguments_.size(); ++i) {
|
for (int i = 0; i < other_arguments_.size(); ++i) {
|
||||||
@ -889,7 +888,7 @@ Status MapDatasetParams::MakeInputPlaceholder(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MapDatasetParams::MakeAttributes(AttributeVector* attr_vector) const {
|
Status MapDatasetParams::GetAttributes(AttributeVector* attr_vector) const {
|
||||||
*attr_vector = {
|
*attr_vector = {
|
||||||
{MapDatasetOp::kFunc, func_},
|
{MapDatasetOp::kFunc, func_},
|
||||||
{MapDatasetOp::kTarguments, type_arguments_},
|
{MapDatasetOp::kTarguments, type_arguments_},
|
||||||
|
@ -130,14 +130,14 @@ class DatasetParams {
|
|||||||
~DatasetParams() {}
|
~DatasetParams() {}
|
||||||
|
|
||||||
// Returns the dataset input values as a TensorValue vector.
|
// Returns the dataset input values as a TensorValue vector.
|
||||||
virtual Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
|
virtual Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) = 0;
|
||||||
|
|
||||||
// Returns the dataset input names as a string vector.
|
// Returns the dataset input names as a string vector.
|
||||||
virtual Status MakeInputPlaceholder(
|
virtual Status GetInputPlaceholder(
|
||||||
std::vector<string>* input_placeholder) const = 0;
|
std::vector<string>* input_placeholder) const = 0;
|
||||||
|
|
||||||
// Returns the dataset attributes as a vector.
|
// Returns the dataset attributes as a vector.
|
||||||
virtual Status MakeAttributes(AttributeVector* attributes) const = 0;
|
virtual Status GetAttributes(AttributeVector* attributes) const = 0;
|
||||||
|
|
||||||
// Checks if the tensor is a dataset variant tensor.
|
// Checks if the tensor is a dataset variant tensor.
|
||||||
static bool IsDatasetTensor(const Tensor& tensor);
|
static bool IsDatasetTensor(const Tensor& tensor);
|
||||||
@ -177,6 +177,8 @@ class DatasetParams {
|
|||||||
int op_version_ = 1;
|
int op_version_ = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// `RangeDatasetParams` is a common dataset parameter type that are used in
|
||||||
|
// testing.
|
||||||
class RangeDatasetParams : public DatasetParams {
|
class RangeDatasetParams : public DatasetParams {
|
||||||
public:
|
public:
|
||||||
RangeDatasetParams(int64 start, int64 stop, int64 step,
|
RangeDatasetParams(int64 start, int64 stop, int64 step,
|
||||||
@ -186,12 +188,12 @@ class RangeDatasetParams : public DatasetParams {
|
|||||||
|
|
||||||
RangeDatasetParams(int64 start, int64 stop, int64 step);
|
RangeDatasetParams(int64 start, int64 stop, int64 step);
|
||||||
|
|
||||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||||
|
|
||||||
Status MakeInputPlaceholder(
|
Status GetInputPlaceholder(
|
||||||
std::vector<string>* input_placeholder) const override;
|
std::vector<string>* input_placeholder) const override;
|
||||||
|
|
||||||
Status MakeAttributes(AttributeVector* attr_vector) const override;
|
Status GetAttributes(AttributeVector* attr_vector) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Tensor start_;
|
Tensor start_;
|
||||||
@ -199,6 +201,8 @@ class RangeDatasetParams : public DatasetParams {
|
|||||||
Tensor step_;
|
Tensor step_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// `BatchDatasetParams` is a common dataset parameter type that are used in
|
||||||
|
// testing.
|
||||||
class BatchDatasetParams : public DatasetParams {
|
class BatchDatasetParams : public DatasetParams {
|
||||||
public:
|
public:
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -218,12 +222,12 @@ class BatchDatasetParams : public DatasetParams {
|
|||||||
std::make_pair(std::move(input_dataset_params_ptr), Tensor()));
|
std::make_pair(std::move(input_dataset_params_ptr), Tensor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||||
|
|
||||||
Status MakeInputPlaceholder(
|
Status GetInputPlaceholder(
|
||||||
std::vector<string>* input_placeholder) const override;
|
std::vector<string>* input_placeholder) const override;
|
||||||
|
|
||||||
Status MakeAttributes(AttributeVector* attr_vector) const override;
|
Status GetAttributes(AttributeVector* attr_vector) const override;
|
||||||
|
|
||||||
int op_version() const override;
|
int op_version() const override;
|
||||||
|
|
||||||
@ -234,6 +238,8 @@ class BatchDatasetParams : public DatasetParams {
|
|||||||
int op_version_ = 2;
|
int op_version_ = 2;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// `MapDatasetParams` is a common dataset parameter type that are used in
|
||||||
|
// testing.
|
||||||
class MapDatasetParams : public DatasetParams {
|
class MapDatasetParams : public DatasetParams {
|
||||||
public:
|
public:
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -258,12 +264,12 @@ class MapDatasetParams : public DatasetParams {
|
|||||||
std::make_pair(std::move(input_dataset_params_ptr), Tensor()));
|
std::make_pair(std::move(input_dataset_params_ptr), Tensor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
Status GetInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override;
|
||||||
|
|
||||||
Status MakeInputPlaceholder(
|
Status GetInputPlaceholder(
|
||||||
std::vector<string>* input_placeholder) const override;
|
std::vector<string>* input_placeholder) const override;
|
||||||
|
|
||||||
Status MakeAttributes(AttributeVector* attr_vector) const override;
|
Status GetAttributes(AttributeVector* attr_vector) const override;
|
||||||
|
|
||||||
std::vector<FunctionDef> func_lib() const override;
|
std::vector<FunctionDef> func_lib() const override;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user