Implement more context methods for tfrt and turn on context_test.

PiperOrigin-RevId: 334035897
Change-Id: I5ae4d51ed9e9c4777095cba781be5ec80f8555f8
This commit is contained in:
Xiao Yu 2020-09-27 12:29:52 -07:00 committed by TensorFlower Gardener
parent 85ba5dd127
commit 517e851ee1
7 changed files with 95 additions and 72 deletions

View File

@@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
void TFE_ContextSetThreadLocalDevicePlacementPolicy( void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
} }
@@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads. // safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) { TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextDevicePlacementPolicy>( return static_cast<TFE_ContextDevicePlacementPolicy>(
context->GetDevicePlacementPolicy()); tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
} }
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) { TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
@@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
} }
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
tensorflow::EagerContext* context = return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->FindFunctionDef(name) != nullptr;
} }
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
} }
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
} }
} // extern "C" } // extern "C"

View File

@@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
} }
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
} }
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
} }
uint64_t TFE_GetContextId(TFE_Context* ctx) { uint64_t TFE_GetContextId(TFE_Context* ctx) {
@@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
} }
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetExecutorForThread(executor->executor());
} }
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
tensorflow::EagerContext* context = return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return new TFE_Executor(&context->Executor());
} }
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto address_space = tensorflow::DeviceNameUtils::AddressSpace( auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
context->HostCPU()->parsed_name()); tensorflow::unwrap(ctx)->HostCPUParsedName());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length()); void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0); str.copy(static_cast<char*>(data), str.length(), 0);
@@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) { TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context = auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto* function_def = context->FindFunctionDef(function_name);
if (function_def == nullptr) { if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound( status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name); "Unable to find FunctionDef with name: ", function_name);
@@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetAllowSoftPlacement(enable);
} }
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetLogDevicePlacement(enable);
} }

View File

@@ -29,8 +29,25 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow { namespace tensorflow {
class EagerExecutor;
// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
// Running operations with input tensors on the wrong device will fail.
DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the operation
// will be blocked till the copy completes. This is the default policy.
DEVICE_PLACEMENT_SILENT = 2,
// Placement policy which silently copies int32 tensors but not other dtypes.
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
// Abstract interface to a context. // Abstract interface to a context.
// //
@@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
// List attributes of available devices // List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0; virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in current TF runtime. For tfrt, it is used by fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Block until all pending nodes are finished. // Block until all pending nodes are finished.
virtual Status AsyncWait() = 0; virtual Status AsyncWait() = 0;
@@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
// already exists. // already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Find and return a added function by its name.
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
// Return the ParsedName of Host CPU device.
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
// Configure soft device placement policy.
virtual void SetAllowSoftPlacement(bool enable) = 0;
// Configure device placement policy logging.
virtual void SetLogDevicePlacement(bool enable) = 0;
// Sets the device placement policy for the current thread.
virtual void SetThreadLocalDevicePlacementPolicy(
ContextDevicePlacementPolicy policy) = 0;
// Returns the device placement policy for the current thread.
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
// For LLVM style RTTI. // For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) { static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt; return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
} }
//===--------------------------------------------------------------------===//
// Following are legacy features in TF Eager Runtime.
// TODO(tf-runtime): Figure out a way to deprecate following features after
// migrated to TFRT.
//===--------------------------------------------------------------------===//
// Clear pending nodes in thread executors and kernel caches.
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in current TF runtime. For tfrt, it is used by fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Return the Eager Executor for current thread. Please note that Eager
// Executor is only used in current TF but not in TFRT.
virtual EagerExecutor& Executor() = 0;
// Update the Eager Executor for current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
// Configure graph collection in RunMetadata.
virtual void SetShouldStoreGraphs(bool value) = 0;
protected: protected:
explicit ImmediateExecutionContext(AbstractContextKind kind) explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {} : AbstractContext(kind) {}

View File

@@ -573,7 +573,7 @@ Status EagerContext::FindFunctionOpData(
return func_lib_def_.LookUp(name, op_data); return func_lib_def_.LookUp(name, op_data);
} }
const FunctionDef* EagerContext::FindFunctionDef(const string& name) { const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
return func_lib_def_.Find(name); return func_lib_def_.Find(name);
} }

View File

@@ -79,21 +79,6 @@ namespace eager {
class RemoteMgr; class RemoteMgr;
} // namespace eager } // namespace eager
// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
// Running operations with input tensors on the wrong device will fail.
DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the operation
// will be blocked till the copy completes. This is the default policy.
DEVICE_PLACEMENT_SILENT = 2,
// Placement policy which silently copies int32 tensors but not other dtypes.
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
class RunMetadataListener { class RunMetadataListener {
public: public:
virtual ~RunMetadataListener() {} virtual ~RunMetadataListener() {}
@@ -186,7 +171,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
std::function<void(std::function<void()>)>* runner() { return &runner_; } std::function<void(std::function<void()>)>* runner() { return &runner_; }
// Specify a executor for this thread. // Specify a executor for this thread.
void SetExecutorForThread(EagerExecutor* executor); void SetExecutorForThread(EagerExecutor* executor) override;
const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list() const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
const { const {
@@ -195,15 +180,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
} }
// Clear pending nodes in thread executors and kernel caches. // Clear pending nodes in thread executors and kernel caches.
void ClearCachesAndThreadExecutors(); void ClearCachesAndThreadExecutors() override;
// Clear pending nodes in default executor and kernel caches. // Clear pending nodes in default executor and kernel caches.
void ClearCachesAndDefaultExecutor(); void ClearCachesAndDefaultExecutor();
// Sets the device placement policy for the current thread. // Sets the device placement policy for the current thread.
void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy); void SetThreadLocalDevicePlacementPolicy(
ContextDevicePlacementPolicy policy) override;
// Returns the device placement policy for the current thread. // Returns the device placement policy for the current thread.
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const; ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;
// Select an appropriate device for an operation. // Select an appropriate device for an operation.
// //
@@ -227,16 +213,19 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status FindFunctionOpData(const string& name, Status FindFunctionOpData(const string& name,
const tensorflow::OpRegistrationData** op_data); const tensorflow::OpRegistrationData** op_data);
const FunctionDef* FindFunctionDef(const string& name); const FunctionDef* FindFunctionDef(const string& name) const override;
Device* HostCPU() const { return host_cpu_device_; } Device* HostCPU() const { return host_cpu_device_; }
Device* CanonicalDevice(Device* d) const { Device* CanonicalDevice(Device* d) const {
return HostCPU() == d ? nullptr : d; return HostCPU() == d ? nullptr : d;
} }
const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
return HostCPU()->parsed_name();
}
GraphCollector* GetGraphCollector() { return &graph_collector_; } GraphCollector* GetGraphCollector() { return &graph_collector_; }
EagerExecutor& Executor(); EagerExecutor& Executor() override;
// Add the given `fdef` to the local FunctionLibraryDefinition. And add an // Add the given `fdef` to the local FunctionLibraryDefinition. And add an
// entry to the KernelAndDevice cache for it if it's not exist. // entry to the KernelAndDevice cache for it if it's not exist.
@@ -267,9 +256,13 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
bool LogDevicePlacement() const { return log_device_placement_; } bool LogDevicePlacement() const { return log_device_placement_; }
void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; } void SetLogDevicePlacement(bool enable) override {
log_device_placement_ = enable;
}
bool AllowSoftPlacement() const { return allow_soft_placement_; } bool AllowSoftPlacement() const { return allow_soft_placement_; }
void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; } void SetAllowSoftPlacement(bool enable) override {
allow_soft_placement_ = enable;
}
bool LogMemory() const { return log_memory_; } bool LogMemory() const { return log_memory_; }
Rendezvous* GetRendezvous() const { return rendezvous_; } Rendezvous* GetRendezvous() const { return rendezvous_; }
@@ -316,7 +309,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
// TODO(apassos) clean up RunMetadata storage. // TODO(apassos) clean up RunMetadata storage.
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; } mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_); bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
void SetShouldStoreGraphs(bool value); void SetShouldStoreGraphs(bool value) override;
RunMetadata* RunMetadataProto() { return &run_metadata_; } RunMetadata* RunMetadataProto() { return &run_metadata_; }
void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_); void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);

View File

@@ -263,6 +263,7 @@ cuda_py_test(
size = "small", size = "small",
srcs = ["context_test.py"], srcs = ["context_test.py"],
python_version = "PY3", python_version = "PY3",
tfrt_enabled = True,
deps = [ deps = [
":context", ":context",
":test", ":test",

View File

@@ -75,6 +75,7 @@ class ContextTest(test.TestCase):
del tensor2 del tensor2
self.assertIs(weak_c(), None) self.assertIs(weak_c(), None)
@test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet')
def testSimpleGraphCollection(self): def testSimpleGraphCollection(self):
@def_function.function @def_function.function
@@ -111,20 +112,24 @@ class ContextTest(test.TestCase):
_ = context.get_function_def('this_should_not_be_found') _ = context.get_function_def('this_should_not_be_found')
@test_util.run_gpu_only @test_util.run_gpu_only
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
def testGetMemoryUsage(self): def testGetMemoryUsage(self):
array_ops.zeros([10]) # Allocate some memory on the GPU. array_ops.zeros([10]) # Allocate some memory on the GPU.
self.assertGreater( self.assertGreater(
context.context().get_total_memory_usage('GPU:0'), 0) context.context().get_total_memory_usage('GPU:0'), 0)
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
def testGetMemoryUsageCPU(self): def testGetMemoryUsageCPU(self):
with self.assertRaisesRegex(ValueError, 'CPU does not support'): with self.assertRaisesRegex(ValueError, 'CPU does not support'):
context.context().get_total_memory_usage('CPU:0') context.context().get_total_memory_usage('CPU:0')
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
def testGetMemoryUsageUnknownDevice(self): def testGetMemoryUsageUnknownDevice(self):
with self.assertRaisesRegex(ValueError, 'Failed parsing device name'): with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
context.context().get_total_memory_usage('unknown_device') context.context().get_total_memory_usage('unknown_device')
@test_util.run_gpu_only @test_util.run_gpu_only
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
def testGetMemoryUsageAmbiguousDevice(self): def testGetMemoryUsageAmbiguousDevice(self):
if len(context.context().list_physical_devices('GPU')) < 2: if len(context.context().list_physical_devices('GPU')) < 2:
self.skipTest('Need at least 2 GPUs') self.skipTest('Need at least 2 GPUs')