Implement more context methods for tfrt and turn on context_test.
PiperOrigin-RevId: 334035897 Change-Id: I5ae4d51ed9e9c4777095cba781be5ec80f8555f8
This commit is contained in:
parent
85ba5dd127
commit
517e851ee1
@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
|
||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalDevicePlacementPolicy(
|
||||
tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||
}
|
||||
|
||||
@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
// safe to call this function from the async EagerExecutor threads.
|
||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||
context->GetDevicePlacementPolicy());
|
||||
tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||
@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return context->FindFunctionDef(name) != nullptr;
|
||||
return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
|
||||
}
|
||||
|
||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
}
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
uint64_t TFE_GetContextId(TFE_Context* ctx) {
|
||||
@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
||||
}
|
||||
|
||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetExecutorForThread(executor->executor());
|
||||
tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return new TFE_Executor(&context->Executor());
|
||||
return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
|
||||
}
|
||||
|
||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||
context->HostCPU()->parsed_name());
|
||||
tensorflow::unwrap(ctx)->HostCPUParsedName());
|
||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||
void* data = tensorflow::port::Malloc(str.length());
|
||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||
@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto* function_def = context->FindFunctionDef(function_name);
|
||||
auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"Unable to find FunctionDef with name: ", function_name);
|
||||
@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
||||
|
||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetAllowSoftPlacement(enable);
|
||||
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
|
||||
}
|
||||
|
||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetLogDevicePlacement(enable);
|
||||
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
|
||||
}
|
||||
|
@ -29,8 +29,25 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class EagerExecutor;
|
||||
|
||||
// LINT.IfChange
|
||||
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
|
||||
enum ContextDevicePlacementPolicy {
|
||||
// Running operations with input tensors on the wrong device will fail.
|
||||
DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||
// Copy the tensor to the right device but log a warning.
|
||||
DEVICE_PLACEMENT_WARN = 1,
|
||||
// Silently copy the tensor, which has a performance cost since the operation
|
||||
// will be blocked till the copy completes. This is the default policy.
|
||||
DEVICE_PLACEMENT_SILENT = 2,
|
||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||
};
|
||||
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
|
||||
|
||||
// Abstract interface to a context.
|
||||
//
|
||||
@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
|
||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||
|
||||
// Initialize the step resource container for a training step. This is used
|
||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||
virtual void StartStep() = 0;
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// already exists.
|
||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||
|
||||
// Find and return a added function by its name.
|
||||
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
|
||||
|
||||
// Return the ParsedName of Host CPU device.
|
||||
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
|
||||
|
||||
// Configure soft device placement policy.
|
||||
virtual void SetAllowSoftPlacement(bool enable) = 0;
|
||||
|
||||
// Configure device placement policy logging.
|
||||
virtual void SetLogDevicePlacement(bool enable) = 0;
|
||||
|
||||
// Sets the device placement policy for the current thread.
|
||||
virtual void SetThreadLocalDevicePlacementPolicy(
|
||||
ContextDevicePlacementPolicy policy) = 0;
|
||||
// Returns the device placement policy for the current thread.
|
||||
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractContext* ptr) {
|
||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are legacy features in TF Eager Runtime.
|
||||
// TODO(tf-runtime): Figure out a way to deprecate following features after
|
||||
// migrated to TFRT.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Clear pending nodes in thread executors and kernel caches.
|
||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||
|
||||
// Initialize the step resource container for a training step. This is used
|
||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||
virtual void StartStep() = 0;
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Return the Eager Executor for current thread. Please note that Eager
|
||||
// Executor is only used in current TF but not in TFRT.
|
||||
virtual EagerExecutor& Executor() = 0;
|
||||
// Update the Eager Executor for current thread.
|
||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||
|
||||
// Configure graph collection in RunMetadata.
|
||||
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||
|
||||
protected:
|
||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||
: AbstractContext(kind) {}
|
||||
|
@ -573,7 +573,7 @@ Status EagerContext::FindFunctionOpData(
|
||||
return func_lib_def_.LookUp(name, op_data);
|
||||
}
|
||||
|
||||
const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
|
||||
const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
|
||||
return func_lib_def_.Find(name);
|
||||
}
|
||||
|
||||
|
@ -79,21 +79,6 @@ namespace eager {
|
||||
class RemoteMgr;
|
||||
} // namespace eager
|
||||
|
||||
// LINT.IfChange
|
||||
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
|
||||
enum ContextDevicePlacementPolicy {
|
||||
// Running operations with input tensors on the wrong device will fail.
|
||||
DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||
// Copy the tensor to the right device but log a warning.
|
||||
DEVICE_PLACEMENT_WARN = 1,
|
||||
// Silently copy the tensor, which has a performance cost since the operation
|
||||
// will be blocked till the copy completes. This is the default policy.
|
||||
DEVICE_PLACEMENT_SILENT = 2,
|
||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||
};
|
||||
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
|
||||
|
||||
class RunMetadataListener {
|
||||
public:
|
||||
virtual ~RunMetadataListener() {}
|
||||
@ -186,7 +171,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
std::function<void(std::function<void()>)>* runner() { return &runner_; }
|
||||
|
||||
// Specify a executor for this thread.
|
||||
void SetExecutorForThread(EagerExecutor* executor);
|
||||
void SetExecutorForThread(EagerExecutor* executor) override;
|
||||
|
||||
const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
|
||||
const {
|
||||
@ -195,15 +180,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
}
|
||||
|
||||
// Clear pending nodes in thread executors and kernel caches.
|
||||
void ClearCachesAndThreadExecutors();
|
||||
void ClearCachesAndThreadExecutors() override;
|
||||
// Clear pending nodes in default executor and kernel caches.
|
||||
void ClearCachesAndDefaultExecutor();
|
||||
|
||||
// Sets the device placement policy for the current thread.
|
||||
void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);
|
||||
void SetThreadLocalDevicePlacementPolicy(
|
||||
ContextDevicePlacementPolicy policy) override;
|
||||
|
||||
// Returns the device placement policy for the current thread.
|
||||
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
|
||||
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;
|
||||
|
||||
// Select an appropriate device for an operation.
|
||||
//
|
||||
@ -227,16 +213,19 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
Status FindFunctionOpData(const string& name,
|
||||
const tensorflow::OpRegistrationData** op_data);
|
||||
|
||||
const FunctionDef* FindFunctionDef(const string& name);
|
||||
const FunctionDef* FindFunctionDef(const string& name) const override;
|
||||
|
||||
Device* HostCPU() const { return host_cpu_device_; }
|
||||
Device* CanonicalDevice(Device* d) const {
|
||||
return HostCPU() == d ? nullptr : d;
|
||||
}
|
||||
const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
|
||||
return HostCPU()->parsed_name();
|
||||
}
|
||||
|
||||
GraphCollector* GetGraphCollector() { return &graph_collector_; }
|
||||
|
||||
EagerExecutor& Executor();
|
||||
EagerExecutor& Executor() override;
|
||||
|
||||
// Add the given `fdef` to the local FunctionLibraryDefinition. And add an
|
||||
// entry to the KernelAndDevice cache for it if it's not exist.
|
||||
@ -267,9 +256,13 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
|
||||
|
||||
bool LogDevicePlacement() const { return log_device_placement_; }
|
||||
void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; }
|
||||
void SetLogDevicePlacement(bool enable) override {
|
||||
log_device_placement_ = enable;
|
||||
}
|
||||
bool AllowSoftPlacement() const { return allow_soft_placement_; }
|
||||
void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; }
|
||||
void SetAllowSoftPlacement(bool enable) override {
|
||||
allow_soft_placement_ = enable;
|
||||
}
|
||||
bool LogMemory() const { return log_memory_; }
|
||||
|
||||
Rendezvous* GetRendezvous() const { return rendezvous_; }
|
||||
@ -316,7 +309,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
// TODO(apassos) clean up RunMetadata storage.
|
||||
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
||||
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
||||
void SetShouldStoreGraphs(bool value);
|
||||
void SetShouldStoreGraphs(bool value) override;
|
||||
RunMetadata* RunMetadataProto() { return &run_metadata_; }
|
||||
void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
|
||||
|
||||
|
@ -263,6 +263,7 @@ cuda_py_test(
|
||||
size = "small",
|
||||
srcs = ["context_test.py"],
|
||||
python_version = "PY3",
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
":context",
|
||||
":test",
|
||||
|
@ -75,6 +75,7 @@ class ContextTest(test.TestCase):
|
||||
del tensor2
|
||||
self.assertIs(weak_c(), None)
|
||||
|
||||
@test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet')
|
||||
def testSimpleGraphCollection(self):
|
||||
|
||||
@def_function.function
|
||||
@ -111,20 +112,24 @@ class ContextTest(test.TestCase):
|
||||
_ = context.get_function_def('this_should_not_be_found')
|
||||
|
||||
@test_util.run_gpu_only
|
||||
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
|
||||
def testGetMemoryUsage(self):
|
||||
array_ops.zeros([10]) # Allocate some memory on the GPU.
|
||||
self.assertGreater(
|
||||
context.context().get_total_memory_usage('GPU:0'), 0)
|
||||
|
||||
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
|
||||
def testGetMemoryUsageCPU(self):
|
||||
with self.assertRaisesRegex(ValueError, 'CPU does not support'):
|
||||
context.context().get_total_memory_usage('CPU:0')
|
||||
|
||||
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
|
||||
def testGetMemoryUsageUnknownDevice(self):
|
||||
with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
|
||||
context.context().get_total_memory_usage('unknown_device')
|
||||
|
||||
@test_util.run_gpu_only
|
||||
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
|
||||
def testGetMemoryUsageAmbiguousDevice(self):
|
||||
if len(context.context().list_physical_devices('GPU')) < 2:
|
||||
self.skipTest('Need at least 2 GPUs')
|
||||
|
Loading…
Reference in New Issue
Block a user