Replace a couple of Enable$FOO(bool) patterns with an Enable and a Disable function; NFC

IMO the end result is more readable.

PiperOrigin-RevId: 326694735
Change-Id: I6fc01d18859fedbc5be9cdca5d928dfd8e56c463
This commit is contained in:
Sanjoy Das 2020-08-14 11:27:05 -07:00 committed by TensorFlower Gardener
parent 02d0c1158c
commit e1ea20fafb
16 changed files with 54 additions and 42 deletions

View File

@ -183,7 +183,7 @@ class DirectSessionFactory : public SessionFactory {
// Must do this before the CPU allocator is created. // Must do this before the CPU allocator is created.
if (options.config.graph_options().build_cost_model() > 0) { if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true); EnableCPUAllocatorFullStats();
} }
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(

View File

@ -56,9 +56,7 @@ Allocator::~Allocator() {}
// If true, cpu allocator collects full stats. // If true, cpu allocator collects full stats.
static bool cpu_allocator_collect_full_stats = false; static bool cpu_allocator_collect_full_stats = false;
void EnableCPUAllocatorFullStats(bool enable) { void EnableCPUAllocatorFullStats() { cpu_allocator_collect_full_stats = true; }
cpu_allocator_collect_full_stats = enable;
}
bool CPUAllocatorFullStatsEnabled() { return cpu_allocator_collect_full_stats; } bool CPUAllocatorFullStatsEnabled() { return cpu_allocator_collect_full_stats; }
string AllocatorAttributes::DebugString() const { string AllocatorAttributes::DebugString() const {

View File

@ -410,14 +410,17 @@ Allocator* cpu_allocator_base();
// call it directly. // call it directly.
Allocator* cpu_allocator(int numa_node = port::kNUMANoAffinity); Allocator* cpu_allocator(int numa_node = port::kNUMANoAffinity);
// If 'enable' is true, the default CPU allocator implementation will collect // Enables AllocatorStats in the default CPU allocator implementation. By
// AllocatorStats. By default, it's disabled. // default, it's disabled.
void EnableCPUAllocatorStats(bool enable); void EnableCPUAllocatorStats();
// Disables AllocatorStats in the default CPU allocator implementation. By
// default, it's disabled.
void DisableCPUAllocatorStats();
bool CPUAllocatorStatsEnabled(); bool CPUAllocatorStatsEnabled();
// If 'enable' is true, the default CPU allocator implementation will collect // Enables full statistics collection in the default CPU allocator
// full statistics. By default, it's disabled. // implementation. By default, it's disabled.
void EnableCPUAllocatorFullStats(bool enable); void EnableCPUAllocatorFullStats();
bool CPUAllocatorFullStatsEnabled(); bool CPUAllocatorFullStatsEnabled();
// An object that does the underlying suballoc/free of memory for a higher-level // An object that does the underlying suballoc/free of memory for a higher-level

View File

@ -133,7 +133,7 @@ TEST(AllocatorAttributesDeathTest, MergeDifferentScopeIds) {
} }
TEST(CPUAllocatorTest, Simple) { TEST(CPUAllocatorTest, Simple) {
EnableCPUAllocatorStats(true); EnableCPUAllocatorStats();
Allocator* a = cpu_allocator(); Allocator* a = cpu_allocator();
std::vector<void*> ptrs; std::vector<void*> ptrs;
for (int s = 1; s < 1024; s++) { for (int s = 1; s < 1024; s++) {
@ -162,7 +162,7 @@ TEST(CPUAllocatorTest, Simple) {
1048576 * sizeof(double)); 1048576 * sizeof(double));
a->ClearStats(); a->ClearStats();
CheckStats(a, 0, 0, 0, 0); CheckStats(a, 0, 0, 0, 0);
EnableCPUAllocatorStats(false); DisableCPUAllocatorStats();
} }
// Define a struct that we will use to observe behavior in the unit tests // Define a struct that we will use to observe behavior in the unit tests
@ -227,13 +227,13 @@ static void BM_Allocation(int iters, int arg) {
std::vector<int> sizes = {256, 4096, 16384, 524288, 512, 1048576}; std::vector<int> sizes = {256, 4096, 16384, 524288, 512, 1048576};
int size_index = 0; int size_index = 0;
if (arg) EnableCPUAllocatorStats(true); if (arg) EnableCPUAllocatorStats();
while (--iters > 0) { while (--iters > 0) {
int bytes = sizes[size_index++ % sizes.size()]; int bytes = sizes[size_index++ % sizes.size()];
void* p = a->AllocateRaw(1, bytes); void* p = a->AllocateRaw(1, bytes);
a->DeallocateRaw(p); a->DeallocateRaw(p);
} }
if (arg) EnableCPUAllocatorStats(false); if (arg) DisableCPUAllocatorStats();
} }
BENCHMARK(BM_Allocation)->Arg(0)->Arg(1); BENCHMARK(BM_Allocation)->Arg(0)->Arg(1);

View File

@ -29,9 +29,8 @@ namespace tensorflow {
// If true, cpu allocator collects more stats. // If true, cpu allocator collects more stats.
static bool cpu_allocator_collect_stats = false; static bool cpu_allocator_collect_stats = false;
void EnableCPUAllocatorStats(bool enable) { void EnableCPUAllocatorStats() { cpu_allocator_collect_stats = true; }
cpu_allocator_collect_stats = enable; void DisableCPUAllocatorStats() { cpu_allocator_collect_stats = false; }
}
bool CPUAllocatorStatsEnabled() { return cpu_allocator_collect_stats; } bool CPUAllocatorStatsEnabled() { return cpu_allocator_collect_stats; }
static const int kMaxTotalAllocationWarnings = 1; static const int kMaxTotalAllocationWarnings = 1;

View File

@ -103,9 +103,9 @@ class Cluster {
// superset of the devices listed in GetDevices/GetDeviceNames(). // superset of the devices listed in GetDevices/GetDeviceNames().
virtual const DeviceSet* GetDeviceSet() const { return nullptr; } virtual const DeviceSet* GetDeviceSet() const { return nullptr; }
// Enables collecting the allocator stats. Call with enable=true must be made // Enables collecting the allocator stats. If called, must be called before
// before Provision(). // Provision().
virtual Status EnablePeakMemoryStats(bool enable) { virtual Status EnablePeakMemoryStats() {
return errors::Unimplemented(strings ::StrCat( return errors::Unimplemented(strings ::StrCat(
"Peak Memory Stats are not supported on ", type(), " clusters")); "Peak Memory Stats are not supported on ", type(), " clusters"));
} }

View File

@ -202,9 +202,9 @@ Status SingleMachine::Run(const GraphDef& graph_def,
return Status::OK(); return Status::OK();
} }
Status SingleMachine::EnablePeakMemoryStats(bool enable) { Status SingleMachine::EnablePeakMemoryStats() {
EnableCPUAllocatorStats(enable); EnableCPUAllocatorStats();
cpu_allocator_stats_enabled_ = enable; cpu_allocator_stats_enabled_ = true;
// No need to enable GPU allocator stats since its stats are always collected. // No need to enable GPU allocator stats since its stats are always collected.
return Status::OK(); return Status::OK();
} }

View File

@ -45,7 +45,7 @@ class SingleMachine : public Cluster {
const DeviceSet* GetDeviceSet() const override { return device_set_.get(); } const DeviceSet* GetDeviceSet() const override { return device_set_.get(); }
Status EnablePeakMemoryStats(bool enable) override; Status EnablePeakMemoryStats() override;
// It requires EnableAllocatorStats(true) be called before Provision(). // It requires EnableAllocatorStats(true) be called before Provision().
Status GetPeakMemoryUsage( Status GetPeakMemoryUsage(

View File

@ -51,7 +51,7 @@ class SingleMachineTest : public ::testing::Test {
#endif #endif
cluster_.reset( cluster_.reset(
new SingleMachine(timeout_s, 3 /* num_cpu_cores */, 0 /* num_gpus */)); new SingleMachine(timeout_s, 3 /* num_cpu_cores */, 0 /* num_gpus */));
TF_CHECK_OK(cluster_->EnablePeakMemoryStats(true)); TF_CHECK_OK(cluster_->EnablePeakMemoryStats());
TF_CHECK_OK(cluster_->Provision()); TF_CHECK_OK(cluster_->Provision());
} }

View File

@ -414,7 +414,7 @@ TEST(GraphTransferer,
GraphTransferer gt; GraphTransferer gt;
gt.EnableStrictCheckMode(false); gt.EnableStrictCheckMode(false);
profile_utils::CpuUtils::EnableClockCycleProfiling(true); profile_utils::CpuUtils::EnableClockCycleProfiling();
ClockCycleProfiler prof; ClockCycleProfiler prof;
prof.Start(); prof.Start();
Status status = gt.LoadGraphFromProtoFile( Status status = gt.LoadGraphFromProtoFile(
@ -447,7 +447,7 @@ TEST(GraphTransferer,
GraphTransferer gt; GraphTransferer gt;
gt.EnableStrictCheckMode(false); gt.EnableStrictCheckMode(false);
profile_utils::CpuUtils::EnableClockCycleProfiling(true); profile_utils::CpuUtils::EnableClockCycleProfiling();
ClockCycleProfiler prof; ClockCycleProfiler prof;
prof.Start(); prof.Start();
Status status = gt.LoadGraphFromProtoFile( Status status = gt.LoadGraphFromProtoFile(
@ -481,7 +481,7 @@ TEST(GraphTransferer,
GraphTransferer gt; GraphTransferer gt;
gt.EnableStrictCheckMode(false); gt.EnableStrictCheckMode(false);
profile_utils::CpuUtils::EnableClockCycleProfiling(true); profile_utils::CpuUtils::EnableClockCycleProfiling();
ClockCycleProfiler prof; ClockCycleProfiler prof;
prof.Start(); prof.Start();
Status status = gt.LoadGraphFromProtoFile( Status status = gt.LoadGraphFromProtoFile(
@ -540,7 +540,7 @@ TEST(GraphTransferer, DISABLED_RunInceptionV3OnHexagonExampleWithFusedGraph) {
TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
CheckHexagonControllerVersion(); CheckHexagonControllerVersion();
profile_utils::CpuUtils::EnableClockCycleProfiling(true); profile_utils::CpuUtils::EnableClockCycleProfiling();
const IRemoteFusedGraphOpsDefinitions* ops_definitions = const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance(); &HexagonOpsDefinitions::getInstance();

View File

@ -54,12 +54,11 @@ uint64 AndroidArmV7ACpuUtilsHelper::GetCurrentClockCycle() {
return static_cast<uint64>(count); return static_cast<uint64>(count);
} }
void AndroidArmV7ACpuUtilsHelper::EnableClockCycleProfiling(const bool enable) { void AndroidArmV7ACpuUtilsHelper::EnableClockCycleProfiling() {
if (!is_initialized_) { if (!is_initialized_) {
// Initialize here to avoid unnecessary initialization // Initialize here to avoid unnecessary initialization
InitializeInternal(); InitializeInternal();
} }
if (enable) {
const int64 cpu0_scaling_min = ReadCpuFrequencyFile(0, "scaling_min"); const int64 cpu0_scaling_min = ReadCpuFrequencyFile(0, "scaling_min");
const int64 cpu0_scaling_max = ReadCpuFrequencyFile(0, "scaling_max"); const int64 cpu0_scaling_max = ReadCpuFrequencyFile(0, "scaling_max");
if (cpu0_scaling_max != cpu0_scaling_min) { if (cpu0_scaling_max != cpu0_scaling_min) {
@ -69,9 +68,14 @@ void AndroidArmV7ACpuUtilsHelper::EnableClockCycleProfiling(const bool enable) {
} }
ResetClockCycle(); ResetClockCycle();
ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0); ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0);
} else { }
ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0);
void AndroidArmV7ACpuUtilsHelper::DisableClockCycleProfiling() {
if (!is_initialized_) {
// Initialize here to avoid unnecessary initialization
InitializeInternal();
} }
ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0);
} }
int64 AndroidArmV7ACpuUtilsHelper::CalculateCpuFrequency() { int64 AndroidArmV7ACpuUtilsHelper::CalculateCpuFrequency() {

View File

@ -36,7 +36,8 @@ class AndroidArmV7ACpuUtilsHelper : public ICpuUtilsHelper {
AndroidArmV7ACpuUtilsHelper() = default; AndroidArmV7ACpuUtilsHelper() = default;
void ResetClockCycle() final; void ResetClockCycle() final;
uint64 GetCurrentClockCycle() final; uint64 GetCurrentClockCycle() final;
void EnableClockCycleProfiling(bool enable) final; void EnableClockCycleProfiling() final;
void DisableClockCycleProfiling() final;
int64 CalculateCpuFrequency() final; int64 CalculateCpuFrequency() final;
private: private:

View File

@ -58,8 +58,12 @@ static ICpuUtilsHelper* cpu_utils_helper_instance_ = nullptr;
GetCpuUtilsHelperSingletonInstance().ResetClockCycle(); GetCpuUtilsHelperSingletonInstance().ResetClockCycle();
} }
/* static */ void CpuUtils::EnableClockCycleProfiling(const bool enable) { /* static */ void CpuUtils::EnableClockCycleProfiling() {
GetCpuUtilsHelperSingletonInstance().EnableClockCycleProfiling(enable); GetCpuUtilsHelperSingletonInstance().EnableClockCycleProfiling();
}
/* static */ void CpuUtils::DisableClockCycleProfiling() {
GetCpuUtilsHelperSingletonInstance().DisableClockCycleProfiling();
} }
/* static */ std::chrono::duration<double> CpuUtils::ConvertClockCycleToTime( /* static */ std::chrono::duration<double> CpuUtils::ConvertClockCycleToTime(

View File

@ -138,9 +138,10 @@ class CpuUtils {
// clock cycle counters from overflowing on some platforms. // clock cycle counters from overflowing on some platforms.
static void ResetClockCycle(); static void ResetClockCycle();
// Enable clock cycle profile // Enable/Disable clock cycle profile
// You can enable / disable profile if it's supported by the platform // You can enable / disable profile if it's supported by the platform
static void EnableClockCycleProfiling(bool enable); static void EnableClockCycleProfiling();
static void DisableClockCycleProfiling();
// Return chrono::duration per each clock // Return chrono::duration per each clock
static std::chrono::duration<double> ConvertClockCycleToTime( static std::chrono::duration<double> ConvertClockCycleToTime(
@ -152,7 +153,8 @@ class CpuUtils {
DefaultCpuUtilsHelper() = default; DefaultCpuUtilsHelper() = default;
void ResetClockCycle() final {} void ResetClockCycle() final {}
uint64 GetCurrentClockCycle() final { return DUMMY_CYCLE_CLOCK; } uint64 GetCurrentClockCycle() final { return DUMMY_CYCLE_CLOCK; }
void EnableClockCycleProfiling(bool /* enable */) final {} void EnableClockCycleProfiling() final {}
void DisableClockCycleProfiling() final {}
int64 CalculateCpuFrequency() final { return INVALID_FREQUENCY; } int64 CalculateCpuFrequency() final { return INVALID_FREQUENCY; }
private: private:

View File

@ -26,7 +26,7 @@ static constexpr bool DBG = false;
class CpuUtilsTest : public ::testing::Test { class CpuUtilsTest : public ::testing::Test {
protected: protected:
void SetUp() override { CpuUtils::EnableClockCycleProfiling(true); } void SetUp() override { CpuUtils::EnableClockCycleProfiling(); }
}; };
TEST_F(CpuUtilsTest, SetUpTestCase) {} TEST_F(CpuUtilsTest, SetUpTestCase) {}

View File

@ -35,9 +35,10 @@ class ICpuUtilsHelper {
virtual void ResetClockCycle() = 0; virtual void ResetClockCycle() = 0;
// Return current clock cycle. // Return current clock cycle.
virtual uint64 GetCurrentClockCycle() = 0; virtual uint64 GetCurrentClockCycle() = 0;
// Enable clock cycle profile // Enable/Disable clock cycle profile
// You can enable / disable profile if it's supported by the platform // You can enable / disable profile if it's supported by the platform
virtual void EnableClockCycleProfiling(bool enable) = 0; virtual void EnableClockCycleProfiling() = 0;
virtual void DisableClockCycleProfiling() = 0;
// Return cpu frequency. // Return cpu frequency.
// CAVEAT: as this method may read file and/or call system calls, // CAVEAT: as this method may read file and/or call system calls,
// this call is supposed to be slow. // this call is supposed to be slow.