diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 5170bb27498..9c58be108fc 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -15,6 +15,7 @@ cc_library( srcs = ["standalone.cc"], hdrs = ["standalone.h"], deps = [ + "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -28,7 +29,6 @@ tf_cc_test( srcs = ["standalone_test.cc"], deps = [ ":standalone", - "//tensorflow/core:all_kernels", "//tensorflow/core:test", "//tensorflow/core:test_main", ] + tf_protos_all(), diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index d23791e510a..b597fd70add 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -101,6 +101,36 @@ tf_cc_test( ], ) +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = [ + "test_util.h", + ], + data = glob(["testdata/*.pbtxt"]), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/framework:protos_all_cc", + "//tensorflow/core/kernels/data:dataset_test_base", + ], +) + +tf_cc_test( + name = "test_util_test", + srcs = ["test_util_test.cc"], + deps = [ + ":test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/data:standalone", + "//tensorflow/core/kernels/data:dataset_test_base", + ], +) + cc_grpc_library( name = "master_cc_grpc_proto", srcs = [":master_proto"], diff --git a/tensorflow/core/data/service/test_util.cc b/tensorflow/core/data/service/test_util.cc new file mode 100644 index 00000000000..1c8c3c21827 --- /dev/null +++ b/tensorflow/core/data/service/test_util.cc @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/test_util.h" + +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" + +namespace tensorflow { +namespace data { +namespace test_util { + +namespace { +constexpr char kTestdataDir[] = + "tensorflow/core/data/service/testdata"; + +// Proto content generated by +// +// import tensorflow.compat.v2 as tf +// tf.enable_v2_behavior() +// +// ds = tf.data.Dataset.range(10) +// ds = ds.map(lambda x: x*x) +// g = tf.compat.v1.GraphDef() +// g.ParseFromString(ds._as_serialized_graph().numpy()) +// print(g) +constexpr char kMapGraphDefFile[] = "map_graph_def.pbtxt"; +} // namespace + +Status map_test_case(GraphDefTestCase* test_case) { + std::string filepath = io::JoinPath(kTestdataDir, kMapGraphDefFile); + GraphDef graph_def; + TF_RETURN_IF_ERROR(ReadTextProto(Env::Default(), filepath, &graph_def)); + int num_elements = 10; + std::vector> outputs(num_elements); + for (int i = 0; i < num_elements; ++i) { + outputs[i] = CreateTensors(TensorShape{}, {{i * i}}); + } + *test_case = {"MapGraph", graph_def, outputs}; + return Status::OK(); +} + +} // namespace test_util +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/test_util.h b/tensorflow/core/data/service/test_util.h new file mode 100644 index 00000000000..a6b4514dd01 --- /dev/null +++ b/tensorflow/core/data/service/test_util.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { +namespace test_util { + +struct GraphDefTestCase { + // Name for the test case. + string name; + // A dataset graph. + GraphDef graph_def; + // The expected output from iterating over the dataset represented by the + // graph. + std::vector> output; +}; + +// Fills in the input test_case pointer with test case data representing the +// dataset tf.data.Dataset.range(10).map(lambda x: x*x). Useful for testing +// dataset graph execution. +Status map_test_case(GraphDefTestCase* test_case); + +} // namespace test_util +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ diff --git a/tensorflow/core/data/service/test_util_test.cc b/tensorflow/core/data/service/test_util_test.cc new file mode 100644 index 00000000000..1bd5ab66afa --- /dev/null +++ b/tensorflow/core/data/service/test_util_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/data/service/test_util.h" + +#include "tensorflow/core/data/standalone.h" +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace test_util { + +TEST(TestUtil, MapTestCase) { + GraphDefTestCase test_case; + TF_ASSERT_OK(map_test_case(&test_case)); + standalone::Dataset::Params params; + std::unique_ptr dataset; + TF_ASSERT_OK( + standalone::Dataset::FromGraph(params, test_case.graph_def, &dataset)); + + std::unique_ptr iterator; + TF_ASSERT_OK(dataset->MakeIterator(&iterator)); + + bool end_of_input = false; + + std::vector> result; + while (!end_of_input) { + std::vector outputs; + TF_ASSERT_OK(iterator->GetNext(&outputs, &end_of_input)); + if (!end_of_input) { + result.push_back(outputs); + } + } + ASSERT_EQ(result.size(), test_case.output.size()); + for (int i = 0; i < result.size(); ++i) { + TF_EXPECT_OK(DatasetOpsTestBase::ExpectEqual(result[i], test_case.output[i], + /*compare_order=*/true)); + } +} + +} // namespace test_util +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/testdata/map_graph_def.pbtxt b/tensorflow/core/data/service/testdata/map_graph_def.pbtxt new file mode 100644 index 00000000000..6bd813febd4 --- /dev/null +++ b/tensorflow/core/data/service/testdata/map_graph_def.pbtxt @@ -0,0 +1,225 @@ +node { + name: "Const/_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } +} +node { + name: "Const/_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 10 + } + } + } +} +node { + name: "Const/_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } +} +node { + name: "RangeDataset/_3" + op: "RangeDataset" + input: "Const/_0" + input: "Const/_1" + input: "Const/_2" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } +} +node { + name: "MapDataset/_4" + op: "MapDataset" + input: "RangeDataset/_3" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "__inference_Dataset_map_lambda_9" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "preserve_cardinality" + value { + b: true + } + } + attr { + key: "use_inter_op_parallelism" + value { + b: true + } + } +} +node { + name: "dataset" + op: "_Retval" + input: "MapDataset/_4" + attr { + key: "T" + value { + type: DT_VARIANT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { + function { + signature { + name: "__inference_Dataset_map_lambda_9" + input_arg { + name: "args_0" + type: DT_INT64 + } + output_arg { + name: "identity" + type: DT_INT64 + } + } + node_def { + name: "mul" + op: "Mul" + input: "args_0" + input: "args_0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + experimental_debug_info { + original_node_names: "mul" + } + } + node_def { + name: "Identity" + op: "Identity" + input: "mul:z:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + experimental_debug_info { + original_node_names: "Identity" + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "args_0" + } + } + } + } + } +} +versions { + producer: 341 + min_consumer: 12 +}