diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 37604951453..91c5a72ad64 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetThreadLocalDevicePlacementPolicy( + tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return static_cast( - context->GetDevicePlacementPolicy()); + tensorflow::unwrap(ctx)->GetDevicePlacementPolicy()); } TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) { @@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - return context->FindFunctionDef(name) != nullptr; + return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(true); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(false); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false); } } // extern "C" diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index eabb159a631..cc2270755bf 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(true); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(false); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false); } uint64_t TFE_GetContextId(TFE_Context* ctx) { @@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) { } void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetExecutorForThread(executor->executor()); + tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor()); } TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - return new TFE_Executor(&context->Executor()); + return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor()); } void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto address_space = tensorflow::DeviceNameUtils::AddressSpace( - context->HostCPU()->parsed_name()); + tensorflow::unwrap(ctx)->HostCPUParsedName()); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); void* data = tensorflow::port::Malloc(str.length()); str.copy(static_cast(data), str.length(), 0); @@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - auto* function_def = context->FindFunctionDef(function_name); + auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name); if (function_def == nullptr) { status->status = tensorflow::errors::NotFound( "Unable to find FunctionDef with name: ", function_name); @@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetAllowSoftPlacement(enable); + tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable); } void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetLogDevicePlacement(enable); + tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable); } diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 02a3320ef65..a3e3857b34b 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -29,8 +29,25 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +class EagerExecutor; + +// LINT.IfChange +// Note: Keep in sync with exported copy of enum in eager/c_api.h. +enum ContextDevicePlacementPolicy { + // Running operations with input tensors on the wrong device will fail. + DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. This is the default policy. + DEVICE_PLACEMENT_SILENT = 2, + // Placement policy which silently copies int32 tensors but not other dtypes. + DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, +}; +// LINT.ThenChange(//tensorflow/c/eager/c_api.h) // Abstract interface to a context. // @@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext { // List attributes of available devices virtual void ListDevices(std::vector* devices) = 0; - virtual void ClearCachesAndThreadExecutors() = 0; - - // Initialize the step resource container for a training step. This is used - // in current TF runtime. For tfrt, it is used by fallback op handler. - virtual void StartStep() = 0; - // Destroy the step resource container for a training step. - virtual void EndStep() = 0; - // Block until all pending nodes are finished. virtual Status AsyncWait() = 0; @@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext { // already exists. virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; + // Find and return a added function by its name. + virtual const FunctionDef* FindFunctionDef(const string& name) const = 0; + + // Return the ParsedName of Host CPU device. + virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; + + // Configure soft device placement policy. + virtual void SetAllowSoftPlacement(bool enable) = 0; + + // Configure device placement policy logging. + virtual void SetLogDevicePlacement(bool enable) = 0; + + // Sets the device placement policy for the current thread. + virtual void SetThreadLocalDevicePlacementPolicy( + ContextDevicePlacementPolicy policy) = 0; + // Returns the device placement policy for the current thread. + virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; + // For LLVM style RTTI. static bool classof(const AbstractContext* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; } + //===--------------------------------------------------------------------===// + // Following are legacy features in TF Eager Runtime. + // TODO(tf-runtime): Figure out a way to deprecate following features after + // migrated to TFRT. + //===--------------------------------------------------------------------===// + // Clear pending nodes in thread executors and kernel caches. + virtual void ClearCachesAndThreadExecutors() = 0; + + // Initialize the step resource container for a training step. This is used + // in current TF runtime. For tfrt, it is used by fallback op handler. + virtual void StartStep() = 0; + // Destroy the step resource container for a training step. + virtual void EndStep() = 0; + + // Return the Eager Executor for current thread. Please note that Eager + // Executor is only used in current TF but not in TFRT. + virtual EagerExecutor& Executor() = 0; + // Update the Eager Executor for current thread. + virtual void SetExecutorForThread(EagerExecutor* executor) = 0; + + // Configure graph collection in RunMetadata. + virtual void SetShouldStoreGraphs(bool value) = 0; + protected: explicit ImmediateExecutionContext(AbstractContextKind kind) : AbstractContext(kind) {} diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 7362b4738b8..757ac1f7783 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -573,7 +573,7 @@ Status EagerContext::FindFunctionOpData( return func_lib_def_.LookUp(name, op_data); } -const FunctionDef* EagerContext::FindFunctionDef(const string& name) { +const FunctionDef* EagerContext::FindFunctionDef(const string& name) const { return func_lib_def_.Find(name); } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index d0f1a6c4e78..f48da696d48 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -79,21 +79,6 @@ namespace eager { class RemoteMgr; } // namespace eager -// LINT.IfChange -// Note: Keep in sync with exported copy of enum in eager/c_api.h. -enum ContextDevicePlacementPolicy { - // Running operations with input tensors on the wrong device will fail. - DEVICE_PLACEMENT_EXPLICIT = 0, - // Copy the tensor to the right device but log a warning. - DEVICE_PLACEMENT_WARN = 1, - // Silently copy the tensor, which has a performance cost since the operation - // will be blocked till the copy completes. This is the default policy. - DEVICE_PLACEMENT_SILENT = 2, - // Placement policy which silently copies int32 tensors but not other dtypes. - DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, -}; -// LINT.ThenChange(//tensorflow/c/eager/c_api.h) - class RunMetadataListener { public: virtual ~RunMetadataListener() {} @@ -186,7 +171,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { std::function)>* runner() { return &runner_; } // Specify a executor for this thread. - void SetExecutorForThread(EagerExecutor* executor); + void SetExecutorForThread(EagerExecutor* executor) override; const std::shared_ptr> prioritized_device_type_list() const { @@ -195,15 +180,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { } // Clear pending nodes in thread executors and kernel caches. - void ClearCachesAndThreadExecutors(); + void ClearCachesAndThreadExecutors() override; // Clear pending nodes in default executor and kernel caches. void ClearCachesAndDefaultExecutor(); // Sets the device placement policy for the current thread. - void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy); + void SetThreadLocalDevicePlacementPolicy( + ContextDevicePlacementPolicy policy) override; // Returns the device placement policy for the current thread. - ContextDevicePlacementPolicy GetDevicePlacementPolicy() const; + ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override; // Select an appropriate device for an operation. // @@ -227,16 +213,19 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { Status FindFunctionOpData(const string& name, const tensorflow::OpRegistrationData** op_data); - const FunctionDef* FindFunctionDef(const string& name); + const FunctionDef* FindFunctionDef(const string& name) const override; Device* HostCPU() const { return host_cpu_device_; } Device* CanonicalDevice(Device* d) const { return HostCPU() == d ? nullptr : d; } + const DeviceNameUtils::ParsedName& HostCPUParsedName() const override { + return HostCPU()->parsed_name(); + } GraphCollector* GetGraphCollector() { return &graph_collector_; } - EagerExecutor& Executor(); + EagerExecutor& Executor() override; // Add the given `fdef` to the local FunctionLibraryDefinition. And add an // entry to the KernelAndDevice cache for it if it's not exist. @@ -267,9 +256,13 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); bool LogDevicePlacement() const { return log_device_placement_; } - void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; } + void SetLogDevicePlacement(bool enable) override { + log_device_placement_ = enable; + } bool AllowSoftPlacement() const { return allow_soft_placement_; } - void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; } + void SetAllowSoftPlacement(bool enable) override { + allow_soft_placement_ = enable; + } bool LogMemory() const { return log_memory_; } Rendezvous* GetRendezvous() const { return rendezvous_; } @@ -316,7 +309,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // TODO(apassos) clean up RunMetadata storage. mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; } bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_); - void SetShouldStoreGraphs(bool value); + void SetShouldStoreGraphs(bool value) override; RunMetadata* RunMetadataProto() { return &run_metadata_; } void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_); diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 50c50ce3c42..766adf5eecc 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -263,6 +263,7 @@ cuda_py_test( size = "small", srcs = ["context_test.py"], python_version = "PY3", + tfrt_enabled = True, deps = [ ":context", ":test", diff --git a/tensorflow/python/eager/context_test.py b/tensorflow/python/eager/context_test.py index 086f943b3b0..fe86104cc0b 100644 --- a/tensorflow/python/eager/context_test.py +++ b/tensorflow/python/eager/context_test.py @@ -75,6 +75,7 @@ class ContextTest(test.TestCase): del tensor2 self.assertIs(weak_c(), None) + @test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet') def testSimpleGraphCollection(self): @def_function.function @@ -111,20 +112,24 @@ class ContextTest(test.TestCase): _ = context.get_function_def('this_should_not_be_found') @test_util.run_gpu_only + @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported') def testGetMemoryUsage(self): array_ops.zeros([10]) # Allocate some memory on the GPU. self.assertGreater( context.context().get_total_memory_usage('GPU:0'), 0) + @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported') def testGetMemoryUsageCPU(self): with self.assertRaisesRegex(ValueError, 'CPU does not support'): context.context().get_total_memory_usage('CPU:0') + @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported') def testGetMemoryUsageUnknownDevice(self): with self.assertRaisesRegex(ValueError, 'Failed parsing device name'): context.context().get_total_memory_usage('unknown_device') @test_util.run_gpu_only + @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported') def testGetMemoryUsageAmbiguousDevice(self): if len(context.context().list_physical_devices('GPU')) < 2: self.skipTest('Need at least 2 GPUs')