diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 7896f8d2414..96ff89f4a41 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -63,7 +63,7 @@ limitations under the License. // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc. #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) #include "tensorflow/core/tfrt/eager/c_api_tfrt.h" -#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed.h" +#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h" #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE #if !defined(IS_MOBILE_PLATFORM) @@ -120,7 +120,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { opts->async); #if !defined(IS_MOBILE_PLATFORM) tfrt_context->SetDistributedManager( - std::make_unique<tfrt::tf::DistributedManagerContextInterface>( + tfrt::tf::CreateDistributedManagerContext( tfrt_context->GetCoreRuntime()->GetHostContext())); #endif // !IS_MOBILE_PLATFORM return tensorflow::wrap(tfrt_context); diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index dd3a0ec3521..bc55d0668aa 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -149,4 +149,9 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, /*fail_fast=*/false, cb); } +std::unique_ptr<Device> NewRemoteDevice(Env* env, + DeviceAttributes device_attribute) { + return std::make_unique<RemoteDevice>(env, device_attribute); +} + } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/remote_device.h b/tensorflow/core/distributed_runtime/remote_device.h index cd53f8f4b9d..6d257b083f8 100644 --- a/tensorflow/core/distributed_runtime/remote_device.h +++ b/tensorflow/core/distributed_runtime/remote_device.h @@ -60,6 +60,9 @@ typedef std::function<void(const Status&, std::vector<Device*>*)> void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, const string& worker_name, NewRemoteDevicesDone done); +// Create Remote Device based on the given attributes. +std::unique_ptr<Device> NewRemoteDevice(Env* env, + DeviceAttributes device_attribute); } // namespace tensorflow #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_