Parallel device: move broadcasting a bit earlier to simplify type signatures
I think we'll eventually need to move implicit broadcasting out of Execute
entirely for gradients to work, and moving it a little bit out helps simplify
things for now.

PiperOrigin-RevId: 315575773
Change-Id: Ib7e42e5f68d7261a431a4d0de01ca471090cd967
commit 0507043058 (parent 41a78aee2a)
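The change narrows `ParallelDevice::Execute` from accepting a mix of parallel and single-device tensors to accepting only `ParallelTensor*`s; implicit broadcasting of single-device tensors now happens in the caller (`ExecuteWithSpecialOps`) before `Execute` is invoked. Judging by the symbols, the hunks below appear to come from parallel_device.cc, parallel_device_lib.cc, and parallel_device_lib.h (file headers were not preserved in this view). As a before/after summary of the signature change, taken from the hunks below:

// Before: Execute accepted a heterogeneous input list and broadcast inside.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
    TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
    const char* operation_name, const TFE_OpAttrs* attributes,
    int expected_max_outputs, TF_Status* status) const;

// After: Execute only ever sees ParallelTensors; broadcasting happens earlier.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
    TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
    const char* operation_name, const TFE_OpAttrs* attributes,
    int expected_max_outputs, TF_Status* status) const;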
@@ -40,6 +40,9 @@ using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
 using MaybeParallelTensorOwned =
     absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
 
+using MaybeParallelTensorUnowned =
+    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
+
 // A ParallelDevice on its own is not registered with a TFE_Context, and so has
 // no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
 // name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
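For context, `MaybeParallelTensorUnowned` is an `absl::variant`, so call sites discriminate with `absl::holds_alternative` and extract with `absl::get`, exactly as the broadcasting loop in the next hunk does. A minimal sketch of that pattern, with a hypothetical `Describe` function and stub forward declarations standing in for the real TF types:

#include "absl/types/variant.h"

class ParallelTensor;        // stub forward declarations for the sketch
struct TFE_TensorHandle;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// Dispatch on which alternative the variant currently holds.
void Describe(const MaybeParallelTensorUnowned& input) {
  if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
    TFE_TensorHandle* single = absl::get<TFE_TensorHandle*>(input);
    (void)single;    // single-device tensor: would be broadcast to each device
  } else {
    ParallelTensor* parallel = absl::get<ParallelTensor*>(input);
    (void)parallel;  // already has one component per device
  }
}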
@@ -141,9 +144,32 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
     result.emplace(std::move(result_content));
     return result;
   }
+  std::vector<ParallelTensor*> parallel_inputs;
+  std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
+  parallel_inputs.reserve(inputs.size());
+  implicitly_broadcast_tensors.reserve(inputs.size());  // not tight
+  for (const auto& input : inputs) {
+    if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
+      // Non-parallel tensors are implicitly broadcast, i.e. set as the input
+      // to each parallel operation.
+      //
+      // TODO(allenl): There may be smarter ways to do this copy in some
+      // cases, i.e. with a collective broadcast. We'll need to be careful
+      // about things that are taken as inputs on the host or on their
+      // existing device (for multi-device functions).
+      std::unique_ptr<ParallelTensor> parallel_tensor(
+          parallel_device.CopyToParallelDevice(
+              context, absl::get<TFE_TensorHandle*>(input), status));
+      if (TF_GetCode(status) != TF_OK) return result;
+      parallel_inputs.push_back(parallel_tensor.get());
+      implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
+    } else {
+      parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
+    }
+  }
   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
       maybe_parallel_results(
-          parallel_device.Execute(context, std::move(inputs), operation_name,
+          parallel_device.Execute(context, parallel_inputs, operation_name,
                                   attributes, expected_max_outputs, status));
   if (!maybe_parallel_results.has_value()) return result;
   std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
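The loop above uses a common C++ ownership split: `parallel_inputs` holds raw, non-owning pointers in input order, while `implicitly_broadcast_tensors` owns only the freshly created broadcast copies and keeps them alive across the `Execute` call (hence the "not tight" reserve: it is an upper bound, used only when an input needs broadcasting). A stripped-down sketch of that pattern, with a hypothetical `Tensor` type and `Broadcast` factory standing in for `CopyToParallelDevice`:

#include <memory>
#include <vector>

struct Tensor {
  bool parallel = false;
};

// Hypothetical stand-in for parallel_device.CopyToParallelDevice.
std::unique_ptr<Tensor> Broadcast(const Tensor& t) {
  auto copy = std::make_unique<Tensor>(t);
  copy->parallel = true;
  return copy;
}

void Run(const std::vector<Tensor*>& mixed_inputs) {
  std::vector<Tensor*> inputs;                      // non-owning, call order
  std::vector<std::unique_ptr<Tensor>> keep_alive;  // owns broadcast copies
  inputs.reserve(mixed_inputs.size());
  keep_alive.reserve(mixed_inputs.size());          // upper bound, "not tight"
  for (Tensor* t : mixed_inputs) {
    if (!t->parallel) {
      std::unique_ptr<Tensor> copy = Broadcast(*t);
      inputs.push_back(copy.get());
      keep_alive.emplace_back(std::move(copy));
    } else {
      inputs.push_back(t);
    }
  }
  // ... pass `inputs` to the execute step; `keep_alive` outlives that call.
}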
@@ -100,7 +100,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
 
 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
 ParallelDevice::Execute(TFE_Context* context,
-                        std::vector<MaybeParallelTensorUnowned> inputs,
+                        const std::vector<ParallelTensor*>& inputs,
                         const char* operation_name,
                         const TFE_OpAttrs* attributes, int expected_max_outputs,
                         TF_Status* status) const {
@@ -129,27 +129,11 @@ ParallelDevice::Execute(TFE_Context* context,
                      status);
     TFE_OpAddAttrs(op.get(), attributes);
     for (int input_index = 0; input_index < inputs.size(); ++input_index) {
-      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
-        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
-        // to each parallel operation.
-        //
-        // TODO(allenl): There may be smarter ways to do this copy in some
-        // cases, i.e. with a collective broadcast. We'll need to be careful
-        // about things that are taken as inputs on the host or on their
-        // existing device (for multi-device functions).
-        TFE_OpAddInput(op.get(),
-                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
-                       status);
-        if (TF_GetCode(status) != TF_OK) return result;
-      } else {
-        // Parallel tensors are divided between operations by device.
-        TFE_OpAddInput(op.get(),
-                       absl::get<ParallelTensor*>(inputs[input_index])
-                           ->tensor(device_index),
-                       status);
-        if (TF_GetCode(status) != TF_OK) return result;
-      }
+      // Parallel tensors are divided between operations by device.
+      TFE_OpAddInput(op.get(), inputs[input_index]->tensor(device_index),
+                     status);
+      if (TF_GetCode(status) != TF_OK) return result;
     }
     std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
     int real_num_outputs = expected_max_outputs;
     // For nested devices, the inner device sees the async executor we've
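After this hunk the inner loop is uniform: for the op instance being built for component device `device_index`, each input contributes its `device_index`-th component handle. A sketch of that fan-out shape, with hypothetical names (each "parallel value" is modeled as a plain vector with one component per device):

#include <vector>

using Component = int;                        // stand-in for TFE_TensorHandle*
using ParallelValue = std::vector<Component>; // one component per device

// Build the argument list for the op instance running on `device_index`.
std::vector<Component> ArgsForDevice(
    const std::vector<const ParallelValue*>& inputs, int device_index) {
  std::vector<Component> args;
  args.reserve(inputs.size());
  for (const ParallelValue* input : inputs) {
    // Mirrors inputs[input_index]->tensor(device_index) in the diff above.
    args.push_back((*input)[device_index]);
  }
  return args;
}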
@@ -52,9 +52,6 @@ using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
 
 class ParallelTensor;
 
-using MaybeParallelTensorUnowned =
-    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
-
 // Forwards operations to `devices`, maintaining ParallelTensor with components
 // placed on each underlying device.
 class ParallelDevice {
@@ -79,10 +76,9 @@ class ParallelDevice {
 
   // Takes a description of a single operation being executed on the
   // ParallelDevice, and in turn runs one operation per component device with
-  // its corresponding inputs from the input ParallelTensors (or
-  // implicitly-mirrored tensors on other devices). Wraps the resulting
-  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
-  // output of the original operation.
+  // its corresponding inputs from the input ParallelTensors. Wraps the
+  // resulting per-device and per-output TFE_TensorHandles into one
+  // ParallelTensor per output of the original operation.
   //
   // Attributes are forwarded to executed operations unmodified.
   //
@@ -90,7 +86,7 @@ class ParallelDevice {
   // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
   // if sanity checks on dtypes/metadata fail.
   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
-      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
+      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
       const char* operation_name, const TFE_OpAttrs* attributes,
       int expected_max_outputs, TF_Status* status) const;
 
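With the new signature, a caller that already holds `ParallelTensor`s invokes `Execute` directly; only `ExecuteWithSpecialOps` needs to broadcast first. A hedged usage sketch against the declaration above, where the surrounding variables (`device`, `context`, `attributes`, `status`, `lhs`, `rhs`, the "Mul" op) are all assumptions for illustration:

// Assumed in scope: ParallelDevice device; TFE_Context* context;
// const TFE_OpAttrs* attributes; TF_Status* status;
// ParallelTensor* lhs; ParallelTensor* rhs.
std::vector<ParallelTensor*> inputs{lhs, rhs};
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> results =
    device.Execute(context, inputs, "Mul", attributes,
                   /*expected_max_outputs=*/1, status);
if (results.has_value() && TF_GetCode(status) == TF_OK) {
  // One ParallelTensor per output of the original operation.
  ParallelTensor* product = (*results)[0].get();
  (void)product;
}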