Parallel device: move broadcasting a bit earlier to simplify type signatures
I think we'll eventually need to move implicit broadcasting out of execute entirely for gradients to work, and moving it a little bit out helps simplify things for now.

PiperOrigin-RevId: 315575773
Change-Id: Ib7e42e5f68d7261a431a4d0de01ca471090cd967
parent 41a78aee2a
commit 0507043058
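For context, "implicit broadcasting" here means copying one regular tensor handle onto every component device so that each per-device operation sees the same input. A minimal sketch of that idea using the public TFE C API; `BroadcastToComponents` and `component_devices` are illustrative names, not part of this change:

#include <string>
#include <vector>

#include "tensorflow/c/eager/c_api.h"

// Copy one regular tensor handle onto every component device. Each entry in
// `component_devices` (an assumed list of full device names) gets its own
// copy; the caller owns the returned handles.
std::vector<TFE_TensorHandle*> BroadcastToComponents(
    TFE_Context* context, TFE_TensorHandle* input,
    const std::vector<std::string>& component_devices, TF_Status* status) {
  std::vector<TFE_TensorHandle*> components;
  components.reserve(component_devices.size());
  for (const std::string& device_name : component_devices) {
    // One copy per component device; stop early on the first bad status.
    TFE_TensorHandle* component = TFE_TensorHandleCopyToDevice(
        input, context, device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return components;
    components.push_back(component);
  }
  return components;
}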
@@ -40,6 +40,9 @@ using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
 using MaybeParallelTensorOwned =
     absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
 
+using MaybeParallelTensorUnowned =
+    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
+
 // A ParallelDevice on its own is not registered with a TFE_Context, and so has
 // no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
 // name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
@@ -141,9 +144,32 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
     result.emplace(std::move(result_content));
     return result;
   }
+  std::vector<ParallelTensor*> parallel_inputs;
+  std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
+  parallel_inputs.reserve(inputs.size());
+  implicitly_broadcast_tensors.reserve(inputs.size());  // not tight
+  for (const auto& input : inputs) {
+    if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
+      // Non-parallel tensors are implicitly broadcast, i.e. set as the input
+      // to each parallel operation.
+      //
+      // TODO(allenl): There may be smarter ways to do this copy in some
+      // cases, i.e. with a collective broadcast. We'll need to be careful
+      // about things that are taken as inputs on the host or on their
+      // existing device (for multi-device functions).
+      std::unique_ptr<ParallelTensor> parallel_tensor(
+          parallel_device.CopyToParallelDevice(
+              context, absl::get<TFE_TensorHandle*>(input), status));
+      if (TF_GetCode(status) != TF_OK) return result;
+      parallel_inputs.push_back(parallel_tensor.get());
+      implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
+    } else {
+      parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
+    }
+  }
   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
       maybe_parallel_results(
-          parallel_device.Execute(context, std::move(inputs), operation_name,
+          parallel_device.Execute(context, parallel_inputs, operation_name,
                                   attributes, expected_max_outputs, status));
   if (!maybe_parallel_results.has_value()) return result;
   std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
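The new loop uses a common C++ ownership pattern: `parallel_inputs` holds non-owning raw pointers for the `Execute` call, while `implicitly_broadcast_tensors` keeps the freshly copied tensors alive until the call returns. A generic sketch of the same pattern, with nothing TensorFlow-specific (all names here are illustrative):

#include <memory>
#include <vector>

struct Tensor {
  int value = 0;
};

// Borrows its inputs; does not take ownership.
int Sum(const std::vector<Tensor*>& inputs) {
  int total = 0;
  for (const Tensor* t : inputs) total += t->value;
  return total;
}

int SumWithDefaults(const std::vector<Tensor*>& maybe_missing) {
  std::vector<Tensor*> inputs;
  std::vector<std::unique_ptr<Tensor>> owned_temporaries;
  inputs.reserve(maybe_missing.size());
  for (Tensor* t : maybe_missing) {
    if (t == nullptr) {
      // Materialize a stand-in; ownership stays in owned_temporaries, so the
      // raw pointer pushed into `inputs` stays valid through the Sum() call.
      auto temp = std::make_unique<Tensor>();
      inputs.push_back(temp.get());
      owned_temporaries.push_back(std::move(temp));
    } else {
      inputs.push_back(t);
    }
  }
  return Sum(inputs);  // temporaries are freed when owned_temporaries dies
}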
@@ -100,7 +100,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
 
 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
 ParallelDevice::Execute(TFE_Context* context,
-                        std::vector<MaybeParallelTensorUnowned> inputs,
+                        const std::vector<ParallelTensor*>& inputs,
                         const char* operation_name,
                         const TFE_OpAttrs* attributes, int expected_max_outputs,
                         TF_Status* status) const {
@@ -129,26 +129,10 @@ ParallelDevice::Execute(TFE_Context* context,
                     status);
     TFE_OpAddAttrs(op.get(), attributes);
     for (int input_index = 0; input_index < inputs.size(); ++input_index) {
-      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
-        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
-        // to each parallel operation.
-        //
-        // TODO(allenl): There may be smarter ways to do this copy in some
-        // cases, i.e. with a collective broadcast. We'll need to be careful
-        // about things that are taken as inputs on the host or on their
-        // existing device (for multi-device functions).
-        TFE_OpAddInput(op.get(),
-                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
-                       status);
-        if (TF_GetCode(status) != TF_OK) return result;
-      } else {
-        // Parallel tensors are divided between operations by device.
-        TFE_OpAddInput(op.get(),
-                       absl::get<ParallelTensor*>(inputs[input_index])
-                           ->tensor(device_index),
-                       status);
-        if (TF_GetCode(status) != TF_OK) return result;
-      }
+      // Parallel tensors are divided between operations by device.
+      TFE_OpAddInput(op.get(), inputs[input_index]->tensor(device_index),
+                     status);
+      if (TF_GetCode(status) != TF_OK) return result;
     }
     std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
     int real_num_outputs = expected_max_outputs;
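With broadcasting handled by the caller, the loop only has to divide parallel tensors between per-device operations: component `device_index` of every input feeds the operation on device `device_index`. A toy sketch of that slicing, where `ParallelTensorish` and its `tensor()` accessor are stand-ins for the real class, not its actual interface:

#include <cstddef>
#include <vector>

struct ParallelTensorish {
  std::vector<int> components;  // one component per device
  int tensor(std::size_t device_index) const {
    return components[device_index];
  }
};

std::vector<std::vector<int>> SliceInputsPerDevice(
    const std::vector<ParallelTensorish*>& inputs, std::size_t num_devices) {
  std::vector<std::vector<int>> per_device_inputs(num_devices);
  for (std::size_t device_index = 0; device_index < num_devices;
       ++device_index) {
    for (const ParallelTensorish* input : inputs) {
      // Component `device_index` of every input feeds device `device_index`.
      per_device_inputs[device_index].push_back(input->tensor(device_index));
    }
  }
  return per_device_inputs;
}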
@@ -52,9 +52,6 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
 
 class ParallelTensor;
 
-using MaybeParallelTensorUnowned =
-    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
-
 // Forwards operations to `devices`, maintaining ParallelTensor with components
 // placed on each underlying device.
 class ParallelDevice {
@@ -79,10 +76,9 @@ class ParallelDevice {
 
   // Takes a description of a single operation being executed on the
   // ParallelDevice, and in turn runs one operation per component device with
-  // its corresponding inputs from the input ParallelTensors (or
-  // implicitly-mirrored tensors on other devices). Wraps the resulting
-  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
-  // output of the original operation.
+  // its corresponding inputs from the input ParallelTensors. Wraps the
+  // resulting per-device and per-output TFE_TensorHandles into one
+  // ParallelTensor per output of the original operation.
   //
   // Attributes are forwarded to executed operations unmodified.
   //
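The "wrapping" this comment describes regroups results: each device produces one handle per output, and the handles for a given output are bundled across devices into one ParallelTensor. A toy sketch of the regrouping with illustrative stand-in types (`Handle`, `ParallelOutput`):

#include <cstddef>
#include <vector>

using Handle = int;                          // stand-in for TFE_TensorHandle*
using ParallelOutput = std::vector<Handle>;  // one component per device

std::vector<ParallelOutput> WrapPerDeviceOutputs(
    const std::vector<std::vector<Handle>>& per_device_outputs) {
  if (per_device_outputs.empty()) return {};
  std::size_t num_outputs = per_device_outputs[0].size();
  std::vector<ParallelOutput> wrapped(num_outputs);
  for (std::size_t output_index = 0; output_index < num_outputs;
       ++output_index) {
    // Gather component `output_index` from every device into one bundle.
    for (const std::vector<Handle>& device_outputs : per_device_outputs) {
      wrapped[output_index].push_back(device_outputs[output_index]);
    }
  }
  return wrapped;
}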
@@ -90,7 +86,7 @@ class ParallelDevice {
   // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
   // if sanity checks on dtypes/metadata fail.
   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
-      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
+      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
       const char* operation_name, const TFE_OpAttrs* attributes,
       int expected_max_outputs, TF_Status* status) const;
 