Implement more context methods for tfrt and turn on context_test.
PiperOrigin-RevId: 334035897 Change-Id: I5ae4d51ed9e9c4777095cba781be5ec80f8555f8
This commit is contained in:
parent
85ba5dd127
commit
517e851ee1
@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
|||||||
|
|
||||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetThreadLocalDevicePlacementPolicy(
|
|
||||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
|||||||
// safe to call this function from the async EagerExecutor threads.
|
// safe to call this function from the async EagerExecutor threads.
|
||||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||||
TFE_Context* ctx) {
|
TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||||
context->GetDevicePlacementPolicy());
|
tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||||
@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||||
tensorflow::EagerContext* context =
|
return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return context->FindFunctionDef(name) != nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t TFE_GetContextId(TFE_Context* ctx) {
|
uint64_t TFE_GetContextId(TFE_Context* ctx) {
|
||||||
@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetExecutorForThread(executor->executor());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return new TFE_Executor(&context->Executor());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||||
tensorflow::EagerContext* context =
|
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||||
context->HostCPU()->parsed_name());
|
tensorflow::unwrap(ctx)->HostCPUParsedName());
|
||||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||||
void* data = tensorflow::port::Malloc(str.length());
|
void* data = tensorflow::port::Malloc(str.length());
|
||||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||||
@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
|||||||
|
|
||||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||||
TF_Buffer* buf, TF_Status* status) {
|
TF_Buffer* buf, TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
auto* function_def = context->FindFunctionDef(function_name);
|
|
||||||
if (function_def == nullptr) {
|
if (function_def == nullptr) {
|
||||||
status->status = tensorflow::errors::NotFound(
|
status->status = tensorflow::errors::NotFound(
|
||||||
"Unable to find FunctionDef with name: ", function_name);
|
"Unable to find FunctionDef with name: ", function_name);
|
||||||
@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
|||||||
|
|
||||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetAllowSoftPlacement(enable);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetLogDevicePlacement(enable);
|
|
||||||
}
|
}
|
||||||
|
@ -29,8 +29,25 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/tstring.h"
|
#include "tensorflow/core/platform/tstring.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
class EagerExecutor;
|
||||||
|
|
||||||
|
// LINT.IfChange
|
||||||
|
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
|
||||||
|
enum ContextDevicePlacementPolicy {
|
||||||
|
// Running operations with input tensors on the wrong device will fail.
|
||||||
|
DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||||
|
// Copy the tensor to the right device but log a warning.
|
||||||
|
DEVICE_PLACEMENT_WARN = 1,
|
||||||
|
// Silently copy the tensor, which has a performance cost since the operation
|
||||||
|
// will be blocked till the copy completes. This is the default policy.
|
||||||
|
DEVICE_PLACEMENT_SILENT = 2,
|
||||||
|
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||||
|
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||||
|
};
|
||||||
|
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
|
||||||
|
|
||||||
// Abstract interface to a context.
|
// Abstract interface to a context.
|
||||||
//
|
//
|
||||||
@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// List attributes of available devices
|
// List attributes of available devices
|
||||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||||
|
|
||||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
|
||||||
|
|
||||||
// Initialize the step resource container for a training step. This is used
|
|
||||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
|
||||||
virtual void StartStep() = 0;
|
|
||||||
// Destroy the step resource container for a training step.
|
|
||||||
virtual void EndStep() = 0;
|
|
||||||
|
|
||||||
// Block until all pending nodes are finished.
|
// Block until all pending nodes are finished.
|
||||||
virtual Status AsyncWait() = 0;
|
virtual Status AsyncWait() = 0;
|
||||||
|
|
||||||
@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// already exists.
|
// already exists.
|
||||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||||
|
|
||||||
|
// Find and return a added function by its name.
|
||||||
|
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
|
||||||
|
|
||||||
|
// Return the ParsedName of Host CPU device.
|
||||||
|
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
|
||||||
|
|
||||||
|
// Configure soft device placement policy.
|
||||||
|
virtual void SetAllowSoftPlacement(bool enable) = 0;
|
||||||
|
|
||||||
|
// Configure device placement policy logging.
|
||||||
|
virtual void SetLogDevicePlacement(bool enable) = 0;
|
||||||
|
|
||||||
|
// Sets the device placement policy for the current thread.
|
||||||
|
virtual void SetThreadLocalDevicePlacementPolicy(
|
||||||
|
ContextDevicePlacementPolicy policy) = 0;
|
||||||
|
// Returns the device placement policy for the current thread.
|
||||||
|
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
||||||
|
|
||||||
// For LLVM style RTTI.
|
// For LLVM style RTTI.
|
||||||
static bool classof(const AbstractContext* ptr) {
|
static bool classof(const AbstractContext* ptr) {
|
||||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Following are legacy features in TF Eager Runtime.
|
||||||
|
// TODO(tf-runtime): Figure out a way to deprecate following features after
|
||||||
|
// migrated to TFRT.
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Clear pending nodes in thread executors and kernel caches.
|
||||||
|
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||||
|
|
||||||
|
// Initialize the step resource container for a training step. This is used
|
||||||
|
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||||
|
virtual void StartStep() = 0;
|
||||||
|
// Destroy the step resource container for a training step.
|
||||||
|
virtual void EndStep() = 0;
|
||||||
|
|
||||||
|
// Return the Eager Executor for current thread. Please note that Eager
|
||||||
|
// Executor is only used in current TF but not in TFRT.
|
||||||
|
virtual EagerExecutor& Executor() = 0;
|
||||||
|
// Update the Eager Executor for current thread.
|
||||||
|
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||||
|
|
||||||
|
// Configure graph collection in RunMetadata.
|
||||||
|
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||||
: AbstractContext(kind) {}
|
: AbstractContext(kind) {}
|
||||||
|
@ -573,7 +573,7 @@ Status EagerContext::FindFunctionOpData(
|
|||||||
return func_lib_def_.LookUp(name, op_data);
|
return func_lib_def_.LookUp(name, op_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
|
const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
|
||||||
return func_lib_def_.Find(name);
|
return func_lib_def_.Find(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,21 +79,6 @@ namespace eager {
|
|||||||
class RemoteMgr;
|
class RemoteMgr;
|
||||||
} // namespace eager
|
} // namespace eager
|
||||||
|
|
||||||
// LINT.IfChange
|
|
||||||
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
|
|
||||||
enum ContextDevicePlacementPolicy {
|
|
||||||
// Running operations with input tensors on the wrong device will fail.
|
|
||||||
DEVICE_PLACEMENT_EXPLICIT = 0,
|
|
||||||
// Copy the tensor to the right device but log a warning.
|
|
||||||
DEVICE_PLACEMENT_WARN = 1,
|
|
||||||
// Silently copy the tensor, which has a performance cost since the operation
|
|
||||||
// will be blocked till the copy completes. This is the default policy.
|
|
||||||
DEVICE_PLACEMENT_SILENT = 2,
|
|
||||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
|
||||||
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
|
||||||
};
|
|
||||||
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
|
|
||||||
|
|
||||||
class RunMetadataListener {
|
class RunMetadataListener {
|
||||||
public:
|
public:
|
||||||
virtual ~RunMetadataListener() {}
|
virtual ~RunMetadataListener() {}
|
||||||
@ -186,7 +171,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
std::function<void(std::function<void()>)>* runner() { return &runner_; }
|
std::function<void(std::function<void()>)>* runner() { return &runner_; }
|
||||||
|
|
||||||
// Specify a executor for this thread.
|
// Specify a executor for this thread.
|
||||||
void SetExecutorForThread(EagerExecutor* executor);
|
void SetExecutorForThread(EagerExecutor* executor) override;
|
||||||
|
|
||||||
const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
|
const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
|
||||||
const {
|
const {
|
||||||
@ -195,15 +180,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Clear pending nodes in thread executors and kernel caches.
|
// Clear pending nodes in thread executors and kernel caches.
|
||||||
void ClearCachesAndThreadExecutors();
|
void ClearCachesAndThreadExecutors() override;
|
||||||
// Clear pending nodes in default executor and kernel caches.
|
// Clear pending nodes in default executor and kernel caches.
|
||||||
void ClearCachesAndDefaultExecutor();
|
void ClearCachesAndDefaultExecutor();
|
||||||
|
|
||||||
// Sets the device placement policy for the current thread.
|
// Sets the device placement policy for the current thread.
|
||||||
void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);
|
void SetThreadLocalDevicePlacementPolicy(
|
||||||
|
ContextDevicePlacementPolicy policy) override;
|
||||||
|
|
||||||
// Returns the device placement policy for the current thread.
|
// Returns the device placement policy for the current thread.
|
||||||
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
|
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;
|
||||||
|
|
||||||
// Select an appropriate device for an operation.
|
// Select an appropriate device for an operation.
|
||||||
//
|
//
|
||||||
@ -227,16 +213,19 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
Status FindFunctionOpData(const string& name,
|
Status FindFunctionOpData(const string& name,
|
||||||
const tensorflow::OpRegistrationData** op_data);
|
const tensorflow::OpRegistrationData** op_data);
|
||||||
|
|
||||||
const FunctionDef* FindFunctionDef(const string& name);
|
const FunctionDef* FindFunctionDef(const string& name) const override;
|
||||||
|
|
||||||
Device* HostCPU() const { return host_cpu_device_; }
|
Device* HostCPU() const { return host_cpu_device_; }
|
||||||
Device* CanonicalDevice(Device* d) const {
|
Device* CanonicalDevice(Device* d) const {
|
||||||
return HostCPU() == d ? nullptr : d;
|
return HostCPU() == d ? nullptr : d;
|
||||||
}
|
}
|
||||||
|
const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
|
||||||
|
return HostCPU()->parsed_name();
|
||||||
|
}
|
||||||
|
|
||||||
GraphCollector* GetGraphCollector() { return &graph_collector_; }
|
GraphCollector* GetGraphCollector() { return &graph_collector_; }
|
||||||
|
|
||||||
EagerExecutor& Executor();
|
EagerExecutor& Executor() override;
|
||||||
|
|
||||||
// Add the given `fdef` to the local FunctionLibraryDefinition. And add an
|
// Add the given `fdef` to the local FunctionLibraryDefinition. And add an
|
||||||
// entry to the KernelAndDevice cache for it if it's not exist.
|
// entry to the KernelAndDevice cache for it if it's not exist.
|
||||||
@ -267,9 +256,13 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
|
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
|
||||||
|
|
||||||
bool LogDevicePlacement() const { return log_device_placement_; }
|
bool LogDevicePlacement() const { return log_device_placement_; }
|
||||||
void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; }
|
void SetLogDevicePlacement(bool enable) override {
|
||||||
|
log_device_placement_ = enable;
|
||||||
|
}
|
||||||
bool AllowSoftPlacement() const { return allow_soft_placement_; }
|
bool AllowSoftPlacement() const { return allow_soft_placement_; }
|
||||||
void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; }
|
void SetAllowSoftPlacement(bool enable) override {
|
||||||
|
allow_soft_placement_ = enable;
|
||||||
|
}
|
||||||
bool LogMemory() const { return log_memory_; }
|
bool LogMemory() const { return log_memory_; }
|
||||||
|
|
||||||
Rendezvous* GetRendezvous() const { return rendezvous_; }
|
Rendezvous* GetRendezvous() const { return rendezvous_; }
|
||||||
@ -316,7 +309,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
// TODO(apassos) clean up RunMetadata storage.
|
// TODO(apassos) clean up RunMetadata storage.
|
||||||
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
||||||
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
||||||
void SetShouldStoreGraphs(bool value);
|
void SetShouldStoreGraphs(bool value) override;
|
||||||
RunMetadata* RunMetadataProto() { return &run_metadata_; }
|
RunMetadata* RunMetadataProto() { return &run_metadata_; }
|
||||||
void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
|
void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
|
||||||
|
|
||||||
|
@ -263,6 +263,7 @@ cuda_py_test(
|
|||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["context_test.py"],
|
srcs = ["context_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
":context",
|
":context",
|
||||||
":test",
|
":test",
|
||||||
|
@ -75,6 +75,7 @@ class ContextTest(test.TestCase):
|
|||||||
del tensor2
|
del tensor2
|
||||||
self.assertIs(weak_c(), None)
|
self.assertIs(weak_c(), None)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet')
|
||||||
def testSimpleGraphCollection(self):
|
def testSimpleGraphCollection(self):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
@ -111,20 +112,24 @@ class ContextTest(test.TestCase):
|
|||||||
_ = context.get_function_def('this_should_not_be_found')
|
_ = context.get_function_def('this_should_not_be_found')
|
||||||
|
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
|
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
|
||||||
def testGetMemoryUsage(self):
|
def testGetMemoryUsage(self):
|
||||||
array_ops.zeros([10]) # Allocate some memory on the GPU.
|
array_ops.zeros([10]) # Allocate some memory on the GPU.
|
||||||
self.assertGreater(
|
self.assertGreater(
|
||||||
context.context().get_total_memory_usage('GPU:0'), 0)
|
context.context().get_total_memory_usage('GPU:0'), 0)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
|
||||||
def testGetMemoryUsageCPU(self):
|
def testGetMemoryUsageCPU(self):
|
||||||
with self.assertRaisesRegex(ValueError, 'CPU does not support'):
|
with self.assertRaisesRegex(ValueError, 'CPU does not support'):
|
||||||
context.context().get_total_memory_usage('CPU:0')
|
context.context().get_total_memory_usage('CPU:0')
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
|
||||||
def testGetMemoryUsageUnknownDevice(self):
|
def testGetMemoryUsageUnknownDevice(self):
|
||||||
with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
|
with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
|
||||||
context.context().get_total_memory_usage('unknown_device')
|
context.context().get_total_memory_usage('unknown_device')
|
||||||
|
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
|
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
|
||||||
def testGetMemoryUsageAmbiguousDevice(self):
|
def testGetMemoryUsageAmbiguousDevice(self):
|
||||||
if len(context.context().list_physical_devices('GPU')) < 2:
|
if len(context.context().list_physical_devices('GPU')) < 2:
|
||||||
self.skipTest('Need at least 2 GPUs')
|
self.skipTest('Need at least 2 GPUs')
|
||||||
|
Loading…
Reference in New Issue
Block a user