[rollback]Add ordering_token input to collective to leverage AutoControlDependency to order them
PiperOrigin-RevId: 338381814 Change-Id: I734f6ba4dd246f1df53c7e59b48b040147309161
This commit is contained in:
parent
4e1e1499fe
commit
6c95db9c14
tensorflow
c/eager/parallel_device
compiler/mlir/tensorflow/ir
core
common_runtime
graph
kernels
ops
python
tools/api/golden
@ -398,7 +398,6 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
|
||||
TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);
|
||||
|
||||
TFE_OpAddInput(op.get(), input, status);
|
||||
TFE_OpAddInputList(op.get(), nullptr, 0, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
@ -489,7 +488,6 @@ void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||
final_op.length());
|
||||
TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
|
||||
TF_AddInput(reduce_desc, x);
|
||||
TF_AddInputList(reduce_desc, nullptr, 0);
|
||||
TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TF_Operation* operations[]{placeholder_op, reduce_op};
|
||||
|
@ -1645,12 +1645,10 @@ Mutually reduces multiple tensors of identical type and shape.
|
||||
TF_Int32Tensor:$group_size,
|
||||
TF_Int32Tensor:$group_key,
|
||||
TF_Int32Tensor:$instance_key,
|
||||
Variadic<TF_ResourceTensor>:$ordering_token,
|
||||
|
||||
TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
|
||||
TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
|
||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint,
|
||||
DefaultValuedAttr<F32Attr, "0.0f">:$timeout_seconds
|
||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@ -1658,7 +1656,6 @@ Mutually reduces multiple tensors of identical type and shape.
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>;
|
||||
}
|
||||
|
||||
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
|
||||
@ -15010,4 +15007,4 @@ execution the transfer corresponds to.}]>:$dynamic_key,
|
||||
let results = (outs);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
|
||||
}
|
||||
}
|
||||
|
@ -378,12 +378,6 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
|
||||
} else if (!input_def.type_attr().empty() &&
|
||||
!input_def.number_attr().empty()) {
|
||||
InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
|
||||
} else if (!input_def.number_attr().empty()) {
|
||||
if (inference_attrs_.find(input_def.number_attr()) ==
|
||||
inference_attrs_.end()) {
|
||||
MutableAttrs()->Set(input_def.number_attr(), num_inputs);
|
||||
inference_attrs_.insert(input_def.number_attr());
|
||||
}
|
||||
} else {
|
||||
return errors::InvalidArgument("Invalid input list definition");
|
||||
}
|
||||
|
@ -360,7 +360,6 @@ class RingGathererTest : public ::testing::Test {
|
||||
.Attr("instance_key", params.instance.instance_key)
|
||||
.Attr("shape", params.instance.shape)
|
||||
.Input(FakeInput(params.instance.data_type))
|
||||
.Input(std::vector<NodeDefBuilder::NodeOut>())
|
||||
.Finalize(&node_def));
|
||||
return GetKernel(node_def, device_type, device);
|
||||
}
|
||||
|
@ -385,7 +385,6 @@ class RingReducerTest : public ::testing::Test {
|
||||
.Attr("instance_key", params.instance.instance_key)
|
||||
.Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
|
||||
.Input(FakeInput(params.instance.data_type))
|
||||
.Input(std::vector<NodeDefBuilder::NodeOut>())
|
||||
.Finalize(&node_def));
|
||||
return GetKernel(node_def, device_type, device);
|
||||
}
|
||||
|
@ -79,20 +79,19 @@ void VerifyAttrs(
|
||||
Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
|
||||
const string& name, const string& device,
|
||||
int instance_key) {
|
||||
auto opts = builder->opts()
|
||||
.WithName(name)
|
||||
.WithDevice(device)
|
||||
.WithAttr("T", DT_FLOAT)
|
||||
.WithAttr("group_size", 2)
|
||||
.WithAttr("group_key", 1)
|
||||
.WithAttr("instance_key", instance_key)
|
||||
.WithAttr("merge_op", "Add")
|
||||
.WithAttr("final_op", "Id")
|
||||
.WithAttr("subdiv_offsets", {1});
|
||||
NodeBuilder node_builder(opts.GetNameForOp("CollectiveReduce"),
|
||||
"CollectiveReduce", opts.op_registry());
|
||||
node_builder.Input(input).Input(std::vector<NodeBuilder::NodeOut>());
|
||||
return opts.FinalizeBuilder(&node_builder);
|
||||
Node* collective_node =
|
||||
ops::UnaryOp("CollectiveReduce", input,
|
||||
builder->opts()
|
||||
.WithName(name)
|
||||
.WithDevice(device)
|
||||
.WithAttr("T", DT_FLOAT)
|
||||
.WithAttr("group_size", 2)
|
||||
.WithAttr("group_key", 1)
|
||||
.WithAttr("instance_key", instance_key)
|
||||
.WithAttr("merge_op", "Add")
|
||||
.WithAttr("final_op", "Id")
|
||||
.WithAttr("subdiv_offsets", {1}));
|
||||
return collective_node;
|
||||
}
|
||||
|
||||
// Initialize the following graph:
|
||||
|
@ -237,7 +237,6 @@ class NcclTestBase : public ::testing::Test {
|
||||
.Attr("instance_key", params.instance.instance_key)
|
||||
.Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
|
||||
.Input(FakeInput(params.instance.data_type))
|
||||
.Input(std::vector<NodeDefBuilder::NodeOut>())
|
||||
.Finalize(&node_def));
|
||||
return GetKernel(node_def, device);
|
||||
}
|
||||
|
@ -21,7 +21,6 @@ namespace tensorflow {
|
||||
|
||||
REGISTER_OP("CollectiveReduce")
|
||||
.Input("input: T")
|
||||
.Input("ordering_token: Nordering_token * resource")
|
||||
.Output("data: T")
|
||||
.Attr("T: {float, float16, float64, int32, int64}")
|
||||
.Attr("group_size: int")
|
||||
@ -33,13 +32,11 @@ REGISTER_OP("CollectiveReduce")
|
||||
.Attr("wait_for: list(int) = []")
|
||||
.Attr("communication_hint: string = 'auto'")
|
||||
.Attr("timeout_seconds: float = 0")
|
||||
.Attr("Nordering_token: int >= 0 = 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("CollectiveGather")
|
||||
.Input("input: T")
|
||||
.Input("ordering_token: Nordering_token * resource")
|
||||
.Output("data: T")
|
||||
.Attr("T: {float, float16, float64, int32, int64}")
|
||||
.Attr("group_size: int")
|
||||
@ -48,7 +45,6 @@ REGISTER_OP("CollectiveGather")
|
||||
.Attr("shape: shape")
|
||||
.Attr("communication_hint: string = 'auto'")
|
||||
.Attr("timeout_seconds: float = 0")
|
||||
.Attr("Nordering_token: int >= 0 = 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
// Scalar input is not supported.
|
||||
@ -115,12 +111,10 @@ REGISTER_OP("CollectiveReduceV2")
|
||||
.Input("group_size: int32")
|
||||
.Input("group_key: int32")
|
||||
.Input("instance_key: int32")
|
||||
.Input("ordering_token: Nordering_token * resource")
|
||||
.Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
|
||||
.Attr("final_op: {'Id', 'Div'}")
|
||||
.Attr("communication_hint: string = 'auto'")
|
||||
.Attr("timeout_seconds: float = 0")
|
||||
.Attr("Nordering_token: int >= 0 = 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
@ -131,10 +125,8 @@ REGISTER_OP("CollectiveGatherV2")
|
||||
.Input("group_size: int32")
|
||||
.Input("group_key: int32")
|
||||
.Input("instance_key: int32")
|
||||
.Input("ordering_token: Nordering_token * resource")
|
||||
.Attr("communication_hint: string = 'auto'")
|
||||
.Attr("timeout_seconds: float = 0")
|
||||
.Attr("Nordering_token: int >= 0 = 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
// Scalar input is not supported.
|
||||
|
@ -142,71 +142,3 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "CollectiveGather"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "ordering_token"
|
||||
type: DT_RESOURCE
|
||||
number_attr: "Nordering_token"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_HALF
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "group_size"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "group_key"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "instance_key"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "shape"
|
||||
type: "shape"
|
||||
}
|
||||
attr {
|
||||
name: "communication_hint"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "auto"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "timeout_seconds"
|
||||
type: "float"
|
||||
default_value {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Nordering_token"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
|
@ -49,67 +49,3 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "CollectiveGatherV2"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "group_size"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "group_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "instance_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "ordering_token"
|
||||
type: DT_RESOURCE
|
||||
number_attr: "Nordering_token"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_HALF
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "communication_hint"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "auto"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "timeout_seconds"
|
||||
type: "float"
|
||||
default_value {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Nordering_token"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
|
@ -295,101 +295,3 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "CollectiveReduce"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "ordering_token"
|
||||
type: DT_RESOURCE
|
||||
number_attr: "Nordering_token"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_HALF
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "group_size"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "group_key"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "instance_key"
|
||||
type: "int"
|
||||
}
|
||||
attr {
|
||||
name: "merge_op"
|
||||
type: "string"
|
||||
allowed_values {
|
||||
list {
|
||||
s: "Min"
|
||||
s: "Max"
|
||||
s: "Mul"
|
||||
s: "Add"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "final_op"
|
||||
type: "string"
|
||||
allowed_values {
|
||||
list {
|
||||
s: "Id"
|
||||
s: "Div"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "subdiv_offsets"
|
||||
type: "list(int)"
|
||||
}
|
||||
attr {
|
||||
name: "wait_for"
|
||||
type: "list(int)"
|
||||
default_value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "communication_hint"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "auto"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "timeout_seconds"
|
||||
type: "float"
|
||||
default_value {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Nordering_token"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
|
@ -137,89 +137,3 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "CollectiveReduceV2"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "group_size"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "group_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "instance_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "ordering_token"
|
||||
type: DT_RESOURCE
|
||||
number_attr: "Nordering_token"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_HALF
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "merge_op"
|
||||
type: "string"
|
||||
allowed_values {
|
||||
list {
|
||||
s: "Min"
|
||||
s: "Max"
|
||||
s: "Mul"
|
||||
s: "Add"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "final_op"
|
||||
type: "string"
|
||||
allowed_values {
|
||||
list {
|
||||
s: "Id"
|
||||
s: "Div"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "communication_hint"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "auto"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "timeout_seconds"
|
||||
type: "float"
|
||||
default_value {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Nordering_token"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
|
@ -7502,11 +7502,6 @@ op {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "ordering_token"
|
||||
type: DT_RESOURCE
|
||||
number_attr: "Nordering_token"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
@ -7554,14 +7549,6 @@ op {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Nordering_token"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
@ -7582,11 +7569,6 @@ op {
|
||||
name: "instance_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "ordering_token"
|
||||
type: DT_RESOURCE
|
||||
number_attr: "Nordering_token"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
@ -7618,14 +7600,6 @@ op {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Nordering_token"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
@ -7674,11 +7648,6 @@ op {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "ordering_token"
|
||||
type: DT_RESOURCE
|
||||
number_attr: "Nordering_token"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
@ -7756,14 +7725,6 @@ op {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Nordering_token"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
@ -7784,11 +7745,6 @@ op {
|
||||
name: "instance_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "ordering_token"
|
||||
type: DT_RESOURCE
|
||||
number_attr: "Nordering_token"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
@ -7842,14 +7798,6 @@ op {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Nordering_token"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
|
@ -31,12 +31,10 @@ from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops as _collective_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -70,19 +68,6 @@ device_combination = (
|
||||
device='GPU', communication=['RING', 'NCCL'], required_gpus=2))
|
||||
|
||||
|
||||
collective_op_combinations = combinations.times(
|
||||
combinations.combine(
|
||||
collective_op=[
|
||||
combinations.NamedObject('all_reduce', CollectiveOpsV1.all_reduce),
|
||||
combinations.NamedObject('all_reduce_v2',
|
||||
CollectiveOpsV2.all_reduce),
|
||||
combinations.NamedObject('all_gather', CollectiveOpsV1.all_gather),
|
||||
combinations.NamedObject('all_gather_v2',
|
||||
CollectiveOpsV2.all_gather),
|
||||
],
|
||||
mode='eager'), device_combination)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.combine(
|
||||
@ -298,7 +283,20 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
run_and_assert(group_size=3, group_key=2)
|
||||
|
||||
|
||||
@combinations.generate(collective_op_combinations)
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.combine(
|
||||
collective_op=[
|
||||
combinations.NamedObject('all_reduce',
|
||||
CollectiveOpsV1.all_reduce),
|
||||
combinations.NamedObject('all_reduce_v2',
|
||||
CollectiveOpsV2.all_reduce),
|
||||
combinations.NamedObject('all_gather',
|
||||
CollectiveOpsV1.all_gather),
|
||||
combinations.NamedObject('all_gather_v2',
|
||||
CollectiveOpsV2.all_gather),
|
||||
],
|
||||
mode='eager'), device_combination))
|
||||
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -551,7 +549,20 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
communication_hint=communication)
|
||||
|
||||
|
||||
@combinations.generate(collective_op_combinations)
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.combine(
|
||||
collective_op=[
|
||||
combinations.NamedObject('all_reduce',
|
||||
CollectiveOpsV1.all_reduce),
|
||||
combinations.NamedObject('all_reduce_v2',
|
||||
CollectiveOpsV2.all_reduce),
|
||||
combinations.NamedObject('all_gather',
|
||||
CollectiveOpsV1.all_gather),
|
||||
combinations.NamedObject('all_gather_v2',
|
||||
CollectiveOpsV2.all_gather),
|
||||
],
|
||||
mode='eager'), device_combination))
|
||||
class TimeoutTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -676,85 +687,6 @@ class TimeoutTest(test.TestCase, parameterized.TestCase):
|
||||
communication_hint=communication)
|
||||
|
||||
|
||||
@combinations.generate(collective_op_combinations)
|
||||
class OrderingTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
_setup_context()
|
||||
super().setUp()
|
||||
|
||||
def testOrdering(self, collective_op, device, communication):
|
||||
dev0 = '/device:%s:0' % device
|
||||
dev1 = '/device:%s:1' % device
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
in_tensor = constant_op.constant([1.])
|
||||
|
||||
with ops.device(dev0):
|
||||
token0 = resource_variable_ops.ResourceVariable(0.)
|
||||
with ops.device(dev1):
|
||||
token1 = resource_variable_ops.ResourceVariable(0.)
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
# Launch the first collective with token.
|
||||
with ops.device(dev0):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
ordering_token=token0.handle)
|
||||
with ops.device(dev1):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
ordering_token=token1.handle)
|
||||
# Launch the second collective without token.
|
||||
with ops.device(dev0):
|
||||
collective_op(in_tensor, group_size, group_key, instance_key)
|
||||
with ops.device(dev1):
|
||||
collective_op(in_tensor, group_size, group_key, instance_key)
|
||||
# Launch the third collective with token.
|
||||
with ops.device(dev0):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
ordering_token=token0.handle)
|
||||
with ops.device(dev1):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
ordering_token=token1.handle)
|
||||
|
||||
graph = f.get_concrete_function().graph
|
||||
for device in [dev0, dev1]:
|
||||
# Try to find the third collective, which should have the first collective
|
||||
# as a control input.
|
||||
third = None
|
||||
for op in graph.get_operations():
|
||||
if (op.type.startswith('Collective') and op.device.endswith(device) and
|
||||
op.control_inputs and
|
||||
op.control_inputs[0].type.startswith('Collective')):
|
||||
self.assertIsNone(third)
|
||||
third = op
|
||||
self.assertIsNotNone(third)
|
||||
# Verify it's not the second collective by looking at the inputs.
|
||||
self.assertTrue(any(v.dtype == dtypes.resource for v in third.inputs))
|
||||
first = third.control_inputs[0]
|
||||
self.assertEqual(third.device, first.device)
|
||||
# Verify it's not the second collective by looking at the inputs.
|
||||
self.assertTrue(any(v.dtype == dtypes.resource for v in first.inputs))
|
||||
self.assertEmpty(first.control_inputs)
|
||||
|
||||
|
||||
def _setup_context():
|
||||
context._reset_context()
|
||||
test_util.set_logical_devices_to_at_least('CPU', 4)
|
||||
|
@ -28,8 +28,7 @@ def all_reduce(t,
|
||||
final_op='Id',
|
||||
subdiv_offsets=(0,),
|
||||
communication_hint='auto',
|
||||
timeout=0,
|
||||
ordering_token=None):
|
||||
timeout=0):
|
||||
"""Reduces tensors collectively, across devices.
|
||||
|
||||
Args:
|
||||
@ -51,9 +50,6 @@ def all_reduce(t,
|
||||
timeout: a float. If set to a non zero, set a completion timeout to detect
|
||||
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
||||
timeout value in seconds. This feature is experimental.
|
||||
ordering_token: an optional resource tensor to pass to the op as inputs.
|
||||
They aren't used by the kernel but allow AutoControlDependency to order
|
||||
the collectives with control dependencies.
|
||||
|
||||
Returns:
|
||||
An Op implementing the distributed reduction.
|
||||
@ -63,8 +59,6 @@ def all_reduce(t,
|
||||
"""
|
||||
if group_size < 1:
|
||||
raise ValueError('Parameter group_size to all_reduce must be at least 1.')
|
||||
if ordering_token is not None:
|
||||
ordering_token = [ordering_token]
|
||||
return gen_collective_ops.collective_reduce(
|
||||
t,
|
||||
group_size=group_size,
|
||||
@ -74,8 +68,7 @@ def all_reduce(t,
|
||||
final_op=final_op,
|
||||
subdiv_offsets=subdiv_offsets,
|
||||
communication_hint=communication_hint.lower(),
|
||||
timeout_seconds=timeout,
|
||||
ordering_token=ordering_token or [])
|
||||
timeout_seconds=timeout)
|
||||
|
||||
|
||||
def all_reduce_v2(t,
|
||||
@ -85,8 +78,7 @@ def all_reduce_v2(t,
|
||||
merge_op='Add',
|
||||
final_op='Id',
|
||||
communication_hint='auto',
|
||||
timeout=0,
|
||||
ordering_token=None):
|
||||
timeout=0):
|
||||
"""Reduces tensors collectively, across devices.
|
||||
|
||||
Args:
|
||||
@ -106,15 +98,10 @@ def all_reduce_v2(t,
|
||||
timeout: a float. If set to a non zero, set a completion timeout to detect
|
||||
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
||||
timeout value in seconds. This feature is experimental.
|
||||
ordering_token: an optional resource tensor to pass to the op as inputs.
|
||||
They aren't used by the kernel but allow AutoControlDependency to order
|
||||
the collectives with control dependencies.
|
||||
|
||||
Returns:
|
||||
An Op implementing the distributed reduction.
|
||||
"""
|
||||
if ordering_token is not None:
|
||||
ordering_token = [ordering_token]
|
||||
return gen_collective_ops.collective_reduce_v2(
|
||||
t,
|
||||
group_size=group_size,
|
||||
@ -123,8 +110,7 @@ def all_reduce_v2(t,
|
||||
merge_op=merge_op,
|
||||
final_op=final_op,
|
||||
communication_hint=communication_hint.lower(),
|
||||
timeout_seconds=timeout,
|
||||
ordering_token=ordering_token or [])
|
||||
timeout_seconds=timeout)
|
||||
|
||||
|
||||
def all_gather(t,
|
||||
@ -132,8 +118,7 @@ def all_gather(t,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint='auto',
|
||||
timeout=0,
|
||||
ordering_token=None):
|
||||
timeout=0):
|
||||
"""Accumulates tensors collectively, across devices, along first dimension.
|
||||
|
||||
Args:
|
||||
@ -148,9 +133,6 @@ def all_gather(t,
|
||||
timeout: a float. If set to a non zero, set a completion timeout to detect
|
||||
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
||||
timeout value in seconds. This feature is experimental.
|
||||
ordering_token: an optional resource tensor to pass to the op as inputs.
|
||||
They aren't used by the kernel but allow AutoControlDependency to order
|
||||
the collectives with control dependencies.
|
||||
|
||||
Returns:
|
||||
An Op implementing the distributed operation.
|
||||
@ -160,8 +142,6 @@ def all_gather(t,
|
||||
"""
|
||||
if group_size < 1:
|
||||
raise ValueError('Parameter group_size to all_gather must be at least 1.')
|
||||
if ordering_token is not None:
|
||||
ordering_token = [ordering_token]
|
||||
return gen_collective_ops.collective_gather(
|
||||
t,
|
||||
shape=[0],
|
||||
@ -169,8 +149,7 @@ def all_gather(t,
|
||||
group_key=group_key,
|
||||
instance_key=instance_key,
|
||||
communication_hint=communication_hint.lower(),
|
||||
timeout_seconds=timeout,
|
||||
ordering_token=ordering_token or [])
|
||||
timeout_seconds=timeout)
|
||||
|
||||
|
||||
def all_gather_v2(t,
|
||||
@ -178,8 +157,7 @@ def all_gather_v2(t,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint='auto',
|
||||
timeout=0,
|
||||
ordering_token=None):
|
||||
timeout=0):
|
||||
"""Accumulates tensors collectively, across devices, along first dimension.
|
||||
|
||||
Args:
|
||||
@ -195,23 +173,17 @@ def all_gather_v2(t,
|
||||
timeout: a float. If set to a non zero, set a completion timeout to detect
|
||||
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
||||
timeout value in seconds. This feature is experimental.
|
||||
ordering_token: an optional resource tensor to pass to the op as inputs.
|
||||
They aren't used by the kernel but allow AutoControlDependency to order
|
||||
the collectives with control dependencies.
|
||||
|
||||
Returns:
|
||||
An Op implementing the distributed operation.
|
||||
"""
|
||||
if ordering_token is not None:
|
||||
ordering_token = [ordering_token]
|
||||
return gen_collective_ops.collective_gather_v2(
|
||||
t,
|
||||
group_size=group_size,
|
||||
group_key=group_key,
|
||||
instance_key=instance_key,
|
||||
communication_hint=communication_hint.lower(),
|
||||
timeout_seconds=timeout,
|
||||
ordering_token=ordering_token or [])
|
||||
timeout_seconds=timeout)
|
||||
|
||||
|
||||
def broadcast_send(t,
|
||||
|
@ -758,11 +758,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "CollectiveGather"
|
||||
argspec: "args=[\'input\', \'ordering_token\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CollectiveGatherV2"
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CollectivePermute"
|
||||
@ -770,11 +770,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "CollectiveReduce"
|
||||
argspec: "args=[\'input\', \'ordering_token\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CollectiveReduceV2"
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CombinedNonMaxSuppression"
|
||||
|
@ -758,11 +758,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "CollectiveGather"
|
||||
argspec: "args=[\'input\', \'ordering_token\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CollectiveGatherV2"
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CollectivePermute"
|
||||
@ -770,11 +770,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "CollectiveReduce"
|
||||
argspec: "args=[\'input\', \'ordering_token\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CollectiveReduceV2"
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CombinedNonMaxSuppression"
|
||||
|
Loading…
Reference in New Issue
Block a user