Move some code from the header of dataset_test_base to the cc file. Also makes a
pass on cleaning up the includes to be clean. This should help with isolating some of the includes in the CC file and hopefully improve some of the build graph's caching. PiperOrigin-RevId: 291474465 Change-Id: I534194d5692dcaa41a7d41127fc4e840b88d90e3
This commit is contained in:
parent
6bc88f5dff
commit
e977b1cfa6
@ -33,12 +33,12 @@ cc_library(
|
||||
":map_dataset_op",
|
||||
":name_utils",
|
||||
":range_dataset_op",
|
||||
":serialization_utils",
|
||||
":take_dataset_op",
|
||||
":tensor_slice_dataset_op",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -46,8 +46,7 @@ cc_library(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,74 +15,23 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/control_flow.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/function_handle_cache.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/kernels/data/batch_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/concatenate_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/map_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/take_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/io/record_writer.h"
|
||||
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/file_system.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -198,20 +147,6 @@ Status IsEqual(const Tensor& t1, const Tensor& t2) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DatasetOpsTestBase::DatasetOpsTestBase()
|
||||
: device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
|
||||
device_type_(DEVICE_CPU),
|
||||
cpu_num_(kDefaultCPUNum),
|
||||
thread_num_(kDefaultThreadNum) {
|
||||
allocator_ = device_->GetAllocator(AllocatorAttributes());
|
||||
}
|
||||
|
||||
DatasetOpsTestBase::~DatasetOpsTestBase() {
|
||||
if (dataset_) {
|
||||
dataset_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
|
||||
switch (a.dtype()) {
|
||||
#define CASE(DT) \
|
||||
|
@ -16,46 +16,32 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/function_handle_cache.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -508,7 +494,13 @@ class TestIterator {
|
||||
// Helpful functions to test Dataset op kernels.
|
||||
class DatasetOpsTestBase : public ::testing::Test {
|
||||
public:
|
||||
DatasetOpsTestBase();
|
||||
DatasetOpsTestBase()
|
||||
: device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
|
||||
device_type_(DEVICE_CPU),
|
||||
cpu_num_(kDefaultCPUNum),
|
||||
thread_num_(kDefaultThreadNum) {
|
||||
allocator_ = device_->GetAllocator(AllocatorAttributes());
|
||||
}
|
||||
|
||||
// Initializes the runtime and creates a dataset and iterator.
|
||||
Status Initialize(const DatasetParams& dataset_params);
|
||||
@ -595,7 +587,11 @@ class DatasetOpsTestBase : public ::testing::Test {
|
||||
protected:
|
||||
// Make destructor protected so that DatasetOpsTestBase objects cannot
|
||||
// be instantiated directly. Only subclasses can be instantiated.
|
||||
~DatasetOpsTestBase() override;
|
||||
~DatasetOpsTestBase() override {
|
||||
if (dataset_) {
|
||||
dataset_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a thread pool for parallel tasks.
|
||||
Status InitThreadPool(int thread_num);
|
||||
|
Loading…
x
Reference in New Issue
Block a user