Implement more context methods for tfrt and turn on context_test.

PiperOrigin-RevId: 334035897
Change-Id: I5ae4d51ed9e9c4777095cba781be5ec80f8555f8
This commit is contained in:
Xiao Yu 2020-09-27 12:29:52 -07:00 committed by TensorFlower Gardener
parent 85ba5dd127
commit 517e851ee1
7 changed files with 95 additions and 72 deletions

View File

@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalDevicePlacementPolicy(
tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextDevicePlacementPolicy>(
context->GetDevicePlacementPolicy());
tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
}
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->FindFunctionDef(name) != nullptr;
return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}
} // extern "C"

View File

@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}
uint64_t TFE_GetContextId(TFE_Context* ctx) {
@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
}
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetExecutorForThread(executor->executor());
tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return new TFE_Executor(&context->Executor());
return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
context->HostCPU()->parsed_name());
tensorflow::unwrap(ctx)->HostCPUParsedName());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto* function_def = context->FindFunctionDef(function_name);
auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetAllowSoftPlacement(enable);
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
}
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetLogDevicePlacement(enable);
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
}

View File

@ -29,8 +29,25 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
class EagerExecutor;
// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
// Running operations with input tensors on the wrong device will fail.
DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the operation
// will be blocked till the copy completes. This is the default policy.
DEVICE_PLACEMENT_SILENT = 2,
// Placement policy which silently copies int32 tensors but not other dtypes.
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
// Abstract interface to a context.
//
@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in current TF runtime. For tfrt, it is used by fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Find and return an added function by its name.
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
// Return the ParsedName of Host CPU device.
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
// Configure soft device placement policy.
virtual void SetAllowSoftPlacement(bool enable) = 0;
// Configure device placement policy logging.
virtual void SetLogDevicePlacement(bool enable) = 0;
// Sets the device placement policy for the current thread.
virtual void SetThreadLocalDevicePlacementPolicy(
ContextDevicePlacementPolicy policy) = 0;
// Returns the device placement policy for the current thread.
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
//===--------------------------------------------------------------------===//
// Following are legacy features in TF Eager Runtime.
// TODO(tf-runtime): Figure out a way to deprecate following features after
// migrated to TFRT.
//===--------------------------------------------------------------------===//
// Clear pending nodes in thread executors and kernel caches.
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in current TF runtime. For tfrt, it is used by fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Return the Eager Executor for current thread. Please note that Eager
// Executor is only used in current TF but not in TFRT.
virtual EagerExecutor& Executor() = 0;
// Update the Eager Executor for current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
// Configure graph collection in RunMetadata.
virtual void SetShouldStoreGraphs(bool value) = 0;
protected:
explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {}

View File

@ -573,7 +573,7 @@ Status EagerContext::FindFunctionOpData(
return func_lib_def_.LookUp(name, op_data);
}
const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
return func_lib_def_.Find(name);
}

View File

@ -79,21 +79,6 @@ namespace eager {
class RemoteMgr;
} // namespace eager
// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
// Running operations with input tensors on the wrong device will fail.
DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the operation
// will be blocked till the copy completes. This is the default policy.
DEVICE_PLACEMENT_SILENT = 2,
// Placement policy which silently copies int32 tensors but not other dtypes.
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
class RunMetadataListener {
public:
virtual ~RunMetadataListener() {}
@ -186,7 +171,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
std::function<void(std::function<void()>)>* runner() { return &runner_; }
// Specify an executor for this thread.
void SetExecutorForThread(EagerExecutor* executor);
void SetExecutorForThread(EagerExecutor* executor) override;
const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
const {
@ -195,15 +180,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
}
// Clear pending nodes in thread executors and kernel caches.
void ClearCachesAndThreadExecutors();
void ClearCachesAndThreadExecutors() override;
// Clear pending nodes in default executor and kernel caches.
void ClearCachesAndDefaultExecutor();
// Sets the device placement policy for the current thread.
void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);
void SetThreadLocalDevicePlacementPolicy(
ContextDevicePlacementPolicy policy) override;
// Returns the device placement policy for the current thread.
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;
// Select an appropriate device for an operation.
//
@ -227,16 +213,19 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status FindFunctionOpData(const string& name,
const tensorflow::OpRegistrationData** op_data);
const FunctionDef* FindFunctionDef(const string& name);
const FunctionDef* FindFunctionDef(const string& name) const override;
Device* HostCPU() const { return host_cpu_device_; }
Device* CanonicalDevice(Device* d) const {
return HostCPU() == d ? nullptr : d;
}
const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
return HostCPU()->parsed_name();
}
GraphCollector* GetGraphCollector() { return &graph_collector_; }
EagerExecutor& Executor();
EagerExecutor& Executor() override;
// Add the given `fdef` to the local FunctionLibraryDefinition. And add an
// entry to the KernelAndDevice cache for it if it doesn't exist.
@ -267,9 +256,13 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
bool LogDevicePlacement() const { return log_device_placement_; }
void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; }
void SetLogDevicePlacement(bool enable) override {
log_device_placement_ = enable;
}
bool AllowSoftPlacement() const { return allow_soft_placement_; }
void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; }
void SetAllowSoftPlacement(bool enable) override {
allow_soft_placement_ = enable;
}
bool LogMemory() const { return log_memory_; }
Rendezvous* GetRendezvous() const { return rendezvous_; }
@ -316,7 +309,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
// TODO(apassos) clean up RunMetadata storage.
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
void SetShouldStoreGraphs(bool value);
void SetShouldStoreGraphs(bool value) override;
RunMetadata* RunMetadataProto() { return &run_metadata_; }
void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);

View File

@ -263,6 +263,7 @@ cuda_py_test(
size = "small",
srcs = ["context_test.py"],
python_version = "PY3",
tfrt_enabled = True,
deps = [
":context",
":test",

View File

@ -75,6 +75,7 @@ class ContextTest(test.TestCase):
del tensor2
self.assertIs(weak_c(), None)
@test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet')
def testSimpleGraphCollection(self):
@def_function.function
@ -111,20 +112,24 @@ class ContextTest(test.TestCase):
_ = context.get_function_def('this_should_not_be_found')
@test_util.run_gpu_only
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
def testGetMemoryUsage(self):
array_ops.zeros([10]) # Allocate some memory on the GPU.
self.assertGreater(
context.context().get_total_memory_usage('GPU:0'), 0)
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
def testGetMemoryUsageCPU(self):
with self.assertRaisesRegex(ValueError, 'CPU does not support'):
context.context().get_total_memory_usage('CPU:0')
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
def testGetMemoryUsageUnknownDevice(self):
with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
context.context().get_total_memory_usage('unknown_device')
@test_util.run_gpu_only
@test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported')
def testGetMemoryUsageAmbiguousDevice(self):
if len(context.context().list_physical_devices('GPU')) < 2:
self.skipTest('Need at least 2 GPUs')