STT-tensorflow/tensorflow/lite/delegates/flex/delegate_data.cc
Gaurav Jain e75d8dc058 Add mirroring for remote tensor handles
When executing on a remote worker, we may have to copy the TensorHandle
for each executed op. To avoid duplicated work, we expand the
TensorHandle to keep track of mirrors which are tied to the lifetime of
the TensorHandle. If a mirror already exists on a remote worker, no
additional copy is needed.

The change consists of the following:
- Add map of remote mirrors in TensorHandle.
- Add `mirror` boolean argument to EagerCopyToDevice which indicates to try
  configuring a mirror if possible.
- Add Device argument to RemoteAddress to handle mirrors.
- Expose a ContextMirroringPolicy for the EagerContext. We plan to add
  additional policies in the future, such as local tensor mirroring.
- Rename ContextDevicePlacementPolicy variables to be consistent with
  ContextMirroringPolicy.

PiperOrigin-RevId: 253945140
2019-06-19 00:39:29 -07:00

57 lines
2.0 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
namespace flex {
DelegateData::DelegateData() {}
DelegateData::~DelegateData() {
if (eager_context_) eager_context_->Unref();
}
tensorflow::Status DelegateData::Prepare(
const tensorflow::SessionOptions& session_options) {
if (eager_context_) {
return tensorflow::Status();
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
auto device_mgr =
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
// Note that Rendezvous is ref-counted so it will be automatically deleted.
tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
eager_context_ = new tensorflow::EagerContext(
session_options,
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/*async=*/false, device_mgr.release(), /*device_mgr_owned*/ true,
rendezvous, nullptr);
return tensorflow::Status();
}
} // namespace flex
} // namespace tflite