Move some code from the header of dataset_test_base to the cc file. Also makes a

pass on cleaning up the includes to be clean.

This should help with isolating some of the includes in the CC file and
hopefully improve some of the build graph's caching.

PiperOrigin-RevId: 291474465
Change-Id: I534194d5692dcaa41a7d41127fc4e840b88d90e3
This commit is contained in:
Anna R 2020-01-24 17:30:18 -08:00 committed by TensorFlower Gardener
parent 6bc88f5dff
commit e977b1cfa6
3 changed files with 28 additions and 98 deletions

View File

@ -33,12 +33,12 @@ cc_library(
":map_dataset_op",
":name_utils",
":range_dataset_op",
":serialization_utils",
":take_dataset_op",
":tensor_slice_dataset_op",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@ -46,8 +46,7 @@ cc_library(
"//tensorflow/core:test",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:function_ops",
"//third_party/eigen3",
"@com_google_absl//absl/strings",
"//tensorflow/core/kernels:ops_testutil",
],
)

View File

@ -15,74 +15,23 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_test_base.h"
#include <algorithm>
#include <complex>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "absl/strings/str_cat.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/batch_dataset_op.h"
#include "tensorflow/core/kernels/data/concatenate_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/map_dataset_op.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "tensorflow/core/kernels/data/take_dataset_op.h"
#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@ -198,20 +147,6 @@ Status IsEqual(const Tensor& t1, const Tensor& t2) {
return Status::OK();
}
DatasetOpsTestBase::DatasetOpsTestBase()
: device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
device_type_(DEVICE_CPU),
cpu_num_(kDefaultCPUNum),
thread_num_(kDefaultThreadNum) {
allocator_ = device_->GetAllocator(AllocatorAttributes());
}
DatasetOpsTestBase::~DatasetOpsTestBase() {
if (dataset_) {
dataset_->Unref();
}
}
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
switch (a.dtype()) {
#define CASE(DT) \

View File

@ -16,46 +16,32 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
#include <stddef.h>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@ -508,7 +494,13 @@ class TestIterator {
// Helpful functions to test Dataset op kernels.
class DatasetOpsTestBase : public ::testing::Test {
public:
DatasetOpsTestBase();
DatasetOpsTestBase()
: device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
device_type_(DEVICE_CPU),
cpu_num_(kDefaultCPUNum),
thread_num_(kDefaultThreadNum) {
allocator_ = device_->GetAllocator(AllocatorAttributes());
}
// Initializes the runtime and creates a dataset and iterator.
Status Initialize(const DatasetParams& dataset_params);
@ -595,7 +587,11 @@ class DatasetOpsTestBase : public ::testing::Test {
protected:
// Make destructor protected so that DatasetOpsTestBase objects cannot
// be instantiated directly. Only subclasses can be instantiated.
~DatasetOpsTestBase() override;
~DatasetOpsTestBase() override {
if (dataset_) {
dataset_->Unref();
}
}
// Creates a thread pool for parallel tasks.
Status InitThreadPool(int thread_num);