[tf.data service] Add test utils for testing against dataset GraphDefs.
PiperOrigin-RevId: 302086003 Change-Id: I9e8ffcabe7d3fab87559deaceb2d795bed336585
This commit is contained in:
parent
38ca061ed3
commit
f023e26d90
@ -15,6 +15,7 @@ cc_library(
|
|||||||
srcs = ["standalone.cc"],
|
srcs = ["standalone.cc"],
|
||||||
hdrs = ["standalone.h"],
|
hdrs = ["standalone.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/core:all_kernels",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -28,7 +29,6 @@ tf_cc_test(
|
|||||||
srcs = ["standalone_test.cc"],
|
srcs = ["standalone_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":standalone",
|
":standalone",
|
||||||
"//tensorflow/core:all_kernels",
|
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
] + tf_protos_all(),
|
] + tf_protos_all(),
|
||||||
|
@ -101,6 +101,36 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "test_util",
|
||||||
|
testonly = True,
|
||||||
|
srcs = ["test_util.cc"],
|
||||||
|
hdrs = [
|
||||||
|
"test_util.h",
|
||||||
|
],
|
||||||
|
data = glob(["testdata/*.pbtxt"]),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/framework:protos_all_cc",
|
||||||
|
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "test_util_test",
|
||||||
|
srcs = ["test_util_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":test_util",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/data:standalone",
|
||||||
|
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_grpc_library(
|
cc_grpc_library(
|
||||||
name = "master_cc_grpc_proto",
|
name = "master_cc_grpc_proto",
|
||||||
srcs = [":master_proto"],
|
srcs = [":master_proto"],
|
||||||
|
58
tensorflow/core/data/service/test_util.cc
Normal file
58
tensorflow/core/data/service/test_util.cc
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/data/service/test_util.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/path.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
namespace test_util {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
constexpr char kTestdataDir[] =
|
||||||
|
"tensorflow/core/data/service/testdata";
|
||||||
|
|
||||||
|
// Proto content generated by
|
||||||
|
//
|
||||||
|
// import tensorflow.compat.v2 as tf
|
||||||
|
// tf.enable_v2_behavior()
|
||||||
|
//
|
||||||
|
// ds = tf.data.Dataset.range(10)
|
||||||
|
// ds = ds.map(lambda x: x*x)
|
||||||
|
// g = tf.compat.v1.GraphDef()
|
||||||
|
// g.ParseFromString(ds._as_serialized_graph().numpy())
|
||||||
|
// print(g)
|
||||||
|
constexpr char kMapGraphDefFile[] = "map_graph_def.pbtxt";
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Status map_test_case(GraphDefTestCase* test_case) {
|
||||||
|
std::string filepath = io::JoinPath(kTestdataDir, kMapGraphDefFile);
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReadTextProto(Env::Default(), filepath, &graph_def));
|
||||||
|
int num_elements = 10;
|
||||||
|
std::vector<std::vector<Tensor>> outputs(num_elements);
|
||||||
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
|
outputs[i] = CreateTensors<int64>(TensorShape{}, {{i * i}});
|
||||||
|
}
|
||||||
|
*test_case = {"MapGraph", graph_def, outputs};
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace test_util
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
44
tensorflow/core/data/service/test_util.h
Normal file
44
tensorflow/core/data/service/test_util.h
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_
|
||||||
|
#define TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
namespace test_util {
|
||||||
|
|
||||||
|
struct GraphDefTestCase {
|
||||||
|
// Name for the test case.
|
||||||
|
string name;
|
||||||
|
// A dataset graph.
|
||||||
|
GraphDef graph_def;
|
||||||
|
// The expected output from iterating over the dataset represented by the
|
||||||
|
// graph.
|
||||||
|
std::vector<std::vector<Tensor>> output;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fills in the input test_case pointer with test case data representing the
|
||||||
|
// dataset tf.data.Dataset.range(10).map(lambda x: x*x). Useful for testing
|
||||||
|
// dataset graph execution.
|
||||||
|
Status map_test_case(GraphDefTestCase* test_case);
|
||||||
|
|
||||||
|
} // namespace test_util
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_
|
57
tensorflow/core/data/service/test_util_test.cc
Normal file
57
tensorflow/core/data/service/test_util_test.cc
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/data/service/test_util.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/data/standalone.h"
|
||||||
|
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace data {
|
||||||
|
namespace test_util {
|
||||||
|
|
||||||
|
TEST(TestUtil, MapTestCase) {
|
||||||
|
GraphDefTestCase test_case;
|
||||||
|
TF_ASSERT_OK(map_test_case(&test_case));
|
||||||
|
standalone::Dataset::Params params;
|
||||||
|
std::unique_ptr<standalone::Dataset> dataset;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
standalone::Dataset::FromGraph(params, test_case.graph_def, &dataset));
|
||||||
|
|
||||||
|
std::unique_ptr<standalone::Iterator> iterator;
|
||||||
|
TF_ASSERT_OK(dataset->MakeIterator(&iterator));
|
||||||
|
|
||||||
|
bool end_of_input = false;
|
||||||
|
|
||||||
|
std::vector<std::vector<Tensor>> result;
|
||||||
|
while (!end_of_input) {
|
||||||
|
std::vector<tensorflow::Tensor> outputs;
|
||||||
|
TF_ASSERT_OK(iterator->GetNext(&outputs, &end_of_input));
|
||||||
|
if (!end_of_input) {
|
||||||
|
result.push_back(outputs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_EQ(result.size(), test_case.output.size());
|
||||||
|
for (int i = 0; i < result.size(); ++i) {
|
||||||
|
TF_EXPECT_OK(DatasetOpsTestBase::ExpectEqual(result[i], test_case.output[i],
|
||||||
|
/*compare_order=*/true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace test_util
|
||||||
|
} // namespace data
|
||||||
|
} // namespace tensorflow
|
225
tensorflow/core/data/service/testdata/map_graph_def.pbtxt
vendored
Normal file
225
tensorflow/core/data/service/testdata/map_graph_def.pbtxt
vendored
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
node {
|
||||||
|
name: "Const/_0"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_INT64
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
int64_val: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Const/_1"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_INT64
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
int64_val: 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Const/_2"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_INT64
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
int64_val: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "RangeDataset/_3"
|
||||||
|
op: "RangeDataset"
|
||||||
|
input: "Const/_0"
|
||||||
|
input: "Const/_1"
|
||||||
|
input: "Const/_2"
|
||||||
|
attr {
|
||||||
|
key: "output_shapes"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "output_types"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "MapDataset/_4"
|
||||||
|
op: "MapDataset"
|
||||||
|
input: "RangeDataset/_3"
|
||||||
|
attr {
|
||||||
|
key: "Targuments"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "f"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "__inference_Dataset_map_lambda_9"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "output_shapes"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "output_types"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "preserve_cardinality"
|
||||||
|
value {
|
||||||
|
b: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "use_inter_op_parallelism"
|
||||||
|
value {
|
||||||
|
b: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "dataset"
|
||||||
|
op: "_Retval"
|
||||||
|
input: "MapDataset/_4"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_VARIANT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "index"
|
||||||
|
value {
|
||||||
|
i: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
library {
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "__inference_Dataset_map_lambda_9"
|
||||||
|
input_arg {
|
||||||
|
name: "args_0"
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "identity"
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "mul"
|
||||||
|
op: "Mul"
|
||||||
|
input: "args_0"
|
||||||
|
input: "args_0"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
original_node_names: "mul"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "Identity"
|
||||||
|
op: "Identity"
|
||||||
|
input: "mul:z:0"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_INT64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
original_node_names: "Identity"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret {
|
||||||
|
key: "identity"
|
||||||
|
value: "Identity:output:0"
|
||||||
|
}
|
||||||
|
arg_attr {
|
||||||
|
key: 0
|
||||||
|
value {
|
||||||
|
attr {
|
||||||
|
key: "_output_shapes"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "_user_specified_name"
|
||||||
|
value {
|
||||||
|
s: "args_0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
versions {
|
||||||
|
producer: 341
|
||||||
|
min_consumer: 12
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user