[rollback]Add ordering_token input to collective to leverage AutoControlDependency to order them
PiperOrigin-RevId: 338381814 Change-Id: I734f6ba4dd246f1df53c7e59b48b040147309161
This commit is contained in:
parent
4e1e1499fe
commit
6c95db9c14
tensorflow
c/eager/parallel_device
compiler/mlir/tensorflow/ir
core
common_runtime
graph
kernels
ops
python
tools/api/golden
@ -398,7 +398,6 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
|
|||||||
TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);
|
TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);
|
||||||
|
|
||||||
TFE_OpAddInput(op.get(), input, status);
|
TFE_OpAddInput(op.get(), input, status);
|
||||||
TFE_OpAddInputList(op.get(), nullptr, 0, status);
|
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
|
|
||||||
TFE_TensorHandle* result_handle;
|
TFE_TensorHandle* result_handle;
|
||||||
@ -489,7 +488,6 @@ void RegisterCollectiveMulFunction(TFE_Context* context,
|
|||||||
final_op.length());
|
final_op.length());
|
||||||
TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
|
TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
|
||||||
TF_AddInput(reduce_desc, x);
|
TF_AddInput(reduce_desc, x);
|
||||||
TF_AddInputList(reduce_desc, nullptr, 0);
|
|
||||||
TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
|
TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
TF_Operation* operations[]{placeholder_op, reduce_op};
|
TF_Operation* operations[]{placeholder_op, reduce_op};
|
||||||
|
@ -1645,12 +1645,10 @@ Mutually reduces multiple tensors of identical type and shape.
|
|||||||
TF_Int32Tensor:$group_size,
|
TF_Int32Tensor:$group_size,
|
||||||
TF_Int32Tensor:$group_key,
|
TF_Int32Tensor:$group_key,
|
||||||
TF_Int32Tensor:$instance_key,
|
TF_Int32Tensor:$instance_key,
|
||||||
Variadic<TF_ResourceTensor>:$ordering_token,
|
|
||||||
|
|
||||||
TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
|
TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op,
|
||||||
TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
|
TF_AnyStrAttrOf<["Id", "Div"]>:$final_op,
|
||||||
DefaultValuedAttr<StrAttr, "auto">:$communication_hint,
|
DefaultValuedAttr<StrAttr, "auto">:$communication_hint
|
||||||
DefaultValuedAttr<F32Attr, "0.0f">:$timeout_seconds
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@ -1658,7 +1656,6 @@ Mutually reduces multiple tensors of identical type and shape.
|
|||||||
);
|
);
|
||||||
|
|
||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
|
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
|
||||||
|
@ -378,12 +378,6 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
|
|||||||
} else if (!input_def.type_attr().empty() &&
|
} else if (!input_def.type_attr().empty() &&
|
||||||
!input_def.number_attr().empty()) {
|
!input_def.number_attr().empty()) {
|
||||||
InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
|
InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
|
||||||
} else if (!input_def.number_attr().empty()) {
|
|
||||||
if (inference_attrs_.find(input_def.number_attr()) ==
|
|
||||||
inference_attrs_.end()) {
|
|
||||||
MutableAttrs()->Set(input_def.number_attr(), num_inputs);
|
|
||||||
inference_attrs_.insert(input_def.number_attr());
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return errors::InvalidArgument("Invalid input list definition");
|
return errors::InvalidArgument("Invalid input list definition");
|
||||||
}
|
}
|
||||||
|
@ -360,7 +360,6 @@ class RingGathererTest : public ::testing::Test {
|
|||||||
.Attr("instance_key", params.instance.instance_key)
|
.Attr("instance_key", params.instance.instance_key)
|
||||||
.Attr("shape", params.instance.shape)
|
.Attr("shape", params.instance.shape)
|
||||||
.Input(FakeInput(params.instance.data_type))
|
.Input(FakeInput(params.instance.data_type))
|
||||||
.Input(std::vector<NodeDefBuilder::NodeOut>())
|
|
||||||
.Finalize(&node_def));
|
.Finalize(&node_def));
|
||||||
return GetKernel(node_def, device_type, device);
|
return GetKernel(node_def, device_type, device);
|
||||||
}
|
}
|
||||||
|
@ -385,7 +385,6 @@ class RingReducerTest : public ::testing::Test {
|
|||||||
.Attr("instance_key", params.instance.instance_key)
|
.Attr("instance_key", params.instance.instance_key)
|
||||||
.Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
|
.Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
|
||||||
.Input(FakeInput(params.instance.data_type))
|
.Input(FakeInput(params.instance.data_type))
|
||||||
.Input(std::vector<NodeDefBuilder::NodeOut>())
|
|
||||||
.Finalize(&node_def));
|
.Finalize(&node_def));
|
||||||
return GetKernel(node_def, device_type, device);
|
return GetKernel(node_def, device_type, device);
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,9 @@ void VerifyAttrs(
|
|||||||
Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
|
Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
|
||||||
const string& name, const string& device,
|
const string& name, const string& device,
|
||||||
int instance_key) {
|
int instance_key) {
|
||||||
auto opts = builder->opts()
|
Node* collective_node =
|
||||||
|
ops::UnaryOp("CollectiveReduce", input,
|
||||||
|
builder->opts()
|
||||||
.WithName(name)
|
.WithName(name)
|
||||||
.WithDevice(device)
|
.WithDevice(device)
|
||||||
.WithAttr("T", DT_FLOAT)
|
.WithAttr("T", DT_FLOAT)
|
||||||
@ -88,11 +90,8 @@ Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
|
|||||||
.WithAttr("instance_key", instance_key)
|
.WithAttr("instance_key", instance_key)
|
||||||
.WithAttr("merge_op", "Add")
|
.WithAttr("merge_op", "Add")
|
||||||
.WithAttr("final_op", "Id")
|
.WithAttr("final_op", "Id")
|
||||||
.WithAttr("subdiv_offsets", {1});
|
.WithAttr("subdiv_offsets", {1}));
|
||||||
NodeBuilder node_builder(opts.GetNameForOp("CollectiveReduce"),
|
return collective_node;
|
||||||
"CollectiveReduce", opts.op_registry());
|
|
||||||
node_builder.Input(input).Input(std::vector<NodeBuilder::NodeOut>());
|
|
||||||
return opts.FinalizeBuilder(&node_builder);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the following graph:
|
// Initialize the following graph:
|
||||||
|
@ -237,7 +237,6 @@ class NcclTestBase : public ::testing::Test {
|
|||||||
.Attr("instance_key", params.instance.instance_key)
|
.Attr("instance_key", params.instance.instance_key)
|
||||||
.Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
|
.Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
|
||||||
.Input(FakeInput(params.instance.data_type))
|
.Input(FakeInput(params.instance.data_type))
|
||||||
.Input(std::vector<NodeDefBuilder::NodeOut>())
|
|
||||||
.Finalize(&node_def));
|
.Finalize(&node_def));
|
||||||
return GetKernel(node_def, device);
|
return GetKernel(node_def, device);
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,6 @@ namespace tensorflow {
|
|||||||
|
|
||||||
REGISTER_OP("CollectiveReduce")
|
REGISTER_OP("CollectiveReduce")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Input("ordering_token: Nordering_token * resource")
|
|
||||||
.Output("data: T")
|
.Output("data: T")
|
||||||
.Attr("T: {float, float16, float64, int32, int64}")
|
.Attr("T: {float, float16, float64, int32, int64}")
|
||||||
.Attr("group_size: int")
|
.Attr("group_size: int")
|
||||||
@ -33,13 +32,11 @@ REGISTER_OP("CollectiveReduce")
|
|||||||
.Attr("wait_for: list(int) = []")
|
.Attr("wait_for: list(int) = []")
|
||||||
.Attr("communication_hint: string = 'auto'")
|
.Attr("communication_hint: string = 'auto'")
|
||||||
.Attr("timeout_seconds: float = 0")
|
.Attr("timeout_seconds: float = 0")
|
||||||
.Attr("Nordering_token: int >= 0 = 0")
|
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn(shape_inference::UnchangedShape);
|
.SetShapeFn(shape_inference::UnchangedShape);
|
||||||
|
|
||||||
REGISTER_OP("CollectiveGather")
|
REGISTER_OP("CollectiveGather")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Input("ordering_token: Nordering_token * resource")
|
|
||||||
.Output("data: T")
|
.Output("data: T")
|
||||||
.Attr("T: {float, float16, float64, int32, int64}")
|
.Attr("T: {float, float16, float64, int32, int64}")
|
||||||
.Attr("group_size: int")
|
.Attr("group_size: int")
|
||||||
@ -48,7 +45,6 @@ REGISTER_OP("CollectiveGather")
|
|||||||
.Attr("shape: shape")
|
.Attr("shape: shape")
|
||||||
.Attr("communication_hint: string = 'auto'")
|
.Attr("communication_hint: string = 'auto'")
|
||||||
.Attr("timeout_seconds: float = 0")
|
.Attr("timeout_seconds: float = 0")
|
||||||
.Attr("Nordering_token: int >= 0 = 0")
|
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
// Scalar input is not supported.
|
// Scalar input is not supported.
|
||||||
@ -115,12 +111,10 @@ REGISTER_OP("CollectiveReduceV2")
|
|||||||
.Input("group_size: int32")
|
.Input("group_size: int32")
|
||||||
.Input("group_key: int32")
|
.Input("group_key: int32")
|
||||||
.Input("instance_key: int32")
|
.Input("instance_key: int32")
|
||||||
.Input("ordering_token: Nordering_token * resource")
|
|
||||||
.Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
|
.Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
|
||||||
.Attr("final_op: {'Id', 'Div'}")
|
.Attr("final_op: {'Id', 'Div'}")
|
||||||
.Attr("communication_hint: string = 'auto'")
|
.Attr("communication_hint: string = 'auto'")
|
||||||
.Attr("timeout_seconds: float = 0")
|
.Attr("timeout_seconds: float = 0")
|
||||||
.Attr("Nordering_token: int >= 0 = 0")
|
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn(shape_inference::UnchangedShape);
|
.SetShapeFn(shape_inference::UnchangedShape);
|
||||||
|
|
||||||
@ -131,10 +125,8 @@ REGISTER_OP("CollectiveGatherV2")
|
|||||||
.Input("group_size: int32")
|
.Input("group_size: int32")
|
||||||
.Input("group_key: int32")
|
.Input("group_key: int32")
|
||||||
.Input("instance_key: int32")
|
.Input("instance_key: int32")
|
||||||
.Input("ordering_token: Nordering_token * resource")
|
|
||||||
.Attr("communication_hint: string = 'auto'")
|
.Attr("communication_hint: string = 'auto'")
|
||||||
.Attr("timeout_seconds: float = 0")
|
.Attr("timeout_seconds: float = 0")
|
||||||
.Attr("Nordering_token: int >= 0 = 0")
|
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
// Scalar input is not supported.
|
// Scalar input is not supported.
|
||||||
|
@ -142,71 +142,3 @@ op {
|
|||||||
}
|
}
|
||||||
is_stateful: true
|
is_stateful: true
|
||||||
}
|
}
|
||||||
op {
|
|
||||||
name: "CollectiveGather"
|
|
||||||
input_arg {
|
|
||||||
name: "input"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "ordering_token"
|
|
||||||
type: DT_RESOURCE
|
|
||||||
number_attr: "Nordering_token"
|
|
||||||
}
|
|
||||||
output_arg {
|
|
||||||
name: "data"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "T"
|
|
||||||
type: "type"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
type: DT_FLOAT
|
|
||||||
type: DT_HALF
|
|
||||||
type: DT_DOUBLE
|
|
||||||
type: DT_INT32
|
|
||||||
type: DT_INT64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "group_size"
|
|
||||||
type: "int"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "group_key"
|
|
||||||
type: "int"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "instance_key"
|
|
||||||
type: "int"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "shape"
|
|
||||||
type: "shape"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "communication_hint"
|
|
||||||
type: "string"
|
|
||||||
default_value {
|
|
||||||
s: "auto"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "timeout_seconds"
|
|
||||||
type: "float"
|
|
||||||
default_value {
|
|
||||||
f: 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "Nordering_token"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: 0
|
|
||||||
}
|
|
||||||
has_minimum: true
|
|
||||||
}
|
|
||||||
is_stateful: true
|
|
||||||
}
|
|
||||||
|
@ -49,67 +49,3 @@ op {
|
|||||||
}
|
}
|
||||||
is_stateful: true
|
is_stateful: true
|
||||||
}
|
}
|
||||||
op {
|
|
||||||
name: "CollectiveGatherV2"
|
|
||||||
input_arg {
|
|
||||||
name: "input"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "group_size"
|
|
||||||
type: DT_INT32
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "group_key"
|
|
||||||
type: DT_INT32
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "instance_key"
|
|
||||||
type: DT_INT32
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "ordering_token"
|
|
||||||
type: DT_RESOURCE
|
|
||||||
number_attr: "Nordering_token"
|
|
||||||
}
|
|
||||||
output_arg {
|
|
||||||
name: "data"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "T"
|
|
||||||
type: "type"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
type: DT_FLOAT
|
|
||||||
type: DT_HALF
|
|
||||||
type: DT_DOUBLE
|
|
||||||
type: DT_INT32
|
|
||||||
type: DT_INT64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "communication_hint"
|
|
||||||
type: "string"
|
|
||||||
default_value {
|
|
||||||
s: "auto"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "timeout_seconds"
|
|
||||||
type: "float"
|
|
||||||
default_value {
|
|
||||||
f: 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "Nordering_token"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: 0
|
|
||||||
}
|
|
||||||
has_minimum: true
|
|
||||||
}
|
|
||||||
is_stateful: true
|
|
||||||
}
|
|
||||||
|
@ -295,101 +295,3 @@ op {
|
|||||||
}
|
}
|
||||||
is_stateful: true
|
is_stateful: true
|
||||||
}
|
}
|
||||||
op {
|
|
||||||
name: "CollectiveReduce"
|
|
||||||
input_arg {
|
|
||||||
name: "input"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "ordering_token"
|
|
||||||
type: DT_RESOURCE
|
|
||||||
number_attr: "Nordering_token"
|
|
||||||
}
|
|
||||||
output_arg {
|
|
||||||
name: "data"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "T"
|
|
||||||
type: "type"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
type: DT_FLOAT
|
|
||||||
type: DT_HALF
|
|
||||||
type: DT_DOUBLE
|
|
||||||
type: DT_INT32
|
|
||||||
type: DT_INT64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "group_size"
|
|
||||||
type: "int"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "group_key"
|
|
||||||
type: "int"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "instance_key"
|
|
||||||
type: "int"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "merge_op"
|
|
||||||
type: "string"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
s: "Min"
|
|
||||||
s: "Max"
|
|
||||||
s: "Mul"
|
|
||||||
s: "Add"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "final_op"
|
|
||||||
type: "string"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
s: "Id"
|
|
||||||
s: "Div"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "subdiv_offsets"
|
|
||||||
type: "list(int)"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "wait_for"
|
|
||||||
type: "list(int)"
|
|
||||||
default_value {
|
|
||||||
list {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "communication_hint"
|
|
||||||
type: "string"
|
|
||||||
default_value {
|
|
||||||
s: "auto"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "timeout_seconds"
|
|
||||||
type: "float"
|
|
||||||
default_value {
|
|
||||||
f: 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "Nordering_token"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: 0
|
|
||||||
}
|
|
||||||
has_minimum: true
|
|
||||||
}
|
|
||||||
is_stateful: true
|
|
||||||
}
|
|
||||||
|
@ -137,89 +137,3 @@ op {
|
|||||||
}
|
}
|
||||||
is_stateful: true
|
is_stateful: true
|
||||||
}
|
}
|
||||||
op {
|
|
||||||
name: "CollectiveReduceV2"
|
|
||||||
input_arg {
|
|
||||||
name: "input"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "group_size"
|
|
||||||
type: DT_INT32
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "group_key"
|
|
||||||
type: DT_INT32
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "instance_key"
|
|
||||||
type: DT_INT32
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "ordering_token"
|
|
||||||
type: DT_RESOURCE
|
|
||||||
number_attr: "Nordering_token"
|
|
||||||
}
|
|
||||||
output_arg {
|
|
||||||
name: "data"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "T"
|
|
||||||
type: "type"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
type: DT_FLOAT
|
|
||||||
type: DT_HALF
|
|
||||||
type: DT_DOUBLE
|
|
||||||
type: DT_INT32
|
|
||||||
type: DT_INT64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "merge_op"
|
|
||||||
type: "string"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
s: "Min"
|
|
||||||
s: "Max"
|
|
||||||
s: "Mul"
|
|
||||||
s: "Add"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "final_op"
|
|
||||||
type: "string"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
s: "Id"
|
|
||||||
s: "Div"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "communication_hint"
|
|
||||||
type: "string"
|
|
||||||
default_value {
|
|
||||||
s: "auto"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "timeout_seconds"
|
|
||||||
type: "float"
|
|
||||||
default_value {
|
|
||||||
f: 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "Nordering_token"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: 0
|
|
||||||
}
|
|
||||||
has_minimum: true
|
|
||||||
}
|
|
||||||
is_stateful: true
|
|
||||||
}
|
|
||||||
|
@ -7502,11 +7502,6 @@ op {
|
|||||||
name: "input"
|
name: "input"
|
||||||
type_attr: "T"
|
type_attr: "T"
|
||||||
}
|
}
|
||||||
input_arg {
|
|
||||||
name: "ordering_token"
|
|
||||||
type: DT_RESOURCE
|
|
||||||
number_attr: "Nordering_token"
|
|
||||||
}
|
|
||||||
output_arg {
|
output_arg {
|
||||||
name: "data"
|
name: "data"
|
||||||
type_attr: "T"
|
type_attr: "T"
|
||||||
@ -7554,14 +7549,6 @@ op {
|
|||||||
f: 0
|
f: 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr {
|
|
||||||
name: "Nordering_token"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: 0
|
|
||||||
}
|
|
||||||
has_minimum: true
|
|
||||||
}
|
|
||||||
is_stateful: true
|
is_stateful: true
|
||||||
}
|
}
|
||||||
op {
|
op {
|
||||||
@ -7582,11 +7569,6 @@ op {
|
|||||||
name: "instance_key"
|
name: "instance_key"
|
||||||
type: DT_INT32
|
type: DT_INT32
|
||||||
}
|
}
|
||||||
input_arg {
|
|
||||||
name: "ordering_token"
|
|
||||||
type: DT_RESOURCE
|
|
||||||
number_attr: "Nordering_token"
|
|
||||||
}
|
|
||||||
output_arg {
|
output_arg {
|
||||||
name: "data"
|
name: "data"
|
||||||
type_attr: "T"
|
type_attr: "T"
|
||||||
@ -7618,14 +7600,6 @@ op {
|
|||||||
f: 0
|
f: 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr {
|
|
||||||
name: "Nordering_token"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: 0
|
|
||||||
}
|
|
||||||
has_minimum: true
|
|
||||||
}
|
|
||||||
is_stateful: true
|
is_stateful: true
|
||||||
}
|
}
|
||||||
op {
|
op {
|
||||||
@ -7674,11 +7648,6 @@ op {
|
|||||||
name: "input"
|
name: "input"
|
||||||
type_attr: "T"
|
type_attr: "T"
|
||||||
}
|
}
|
||||||
input_arg {
|
|
||||||
name: "ordering_token"
|
|
||||||
type: DT_RESOURCE
|
|
||||||
number_attr: "Nordering_token"
|
|
||||||
}
|
|
||||||
output_arg {
|
output_arg {
|
||||||
name: "data"
|
name: "data"
|
||||||
type_attr: "T"
|
type_attr: "T"
|
||||||
@ -7756,14 +7725,6 @@ op {
|
|||||||
f: 0
|
f: 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr {
|
|
||||||
name: "Nordering_token"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: 0
|
|
||||||
}
|
|
||||||
has_minimum: true
|
|
||||||
}
|
|
||||||
is_stateful: true
|
is_stateful: true
|
||||||
}
|
}
|
||||||
op {
|
op {
|
||||||
@ -7784,11 +7745,6 @@ op {
|
|||||||
name: "instance_key"
|
name: "instance_key"
|
||||||
type: DT_INT32
|
type: DT_INT32
|
||||||
}
|
}
|
||||||
input_arg {
|
|
||||||
name: "ordering_token"
|
|
||||||
type: DT_RESOURCE
|
|
||||||
number_attr: "Nordering_token"
|
|
||||||
}
|
|
||||||
output_arg {
|
output_arg {
|
||||||
name: "data"
|
name: "data"
|
||||||
type_attr: "T"
|
type_attr: "T"
|
||||||
@ -7842,14 +7798,6 @@ op {
|
|||||||
f: 0
|
f: 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr {
|
|
||||||
name: "Nordering_token"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: 0
|
|
||||||
}
|
|
||||||
has_minimum: true
|
|
||||||
}
|
|
||||||
is_stateful: true
|
is_stateful: true
|
||||||
}
|
}
|
||||||
op {
|
op {
|
||||||
|
@ -31,12 +31,10 @@ from tensorflow.python.distribute import test_util
|
|||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import collective_ops as _collective_ops
|
from tensorflow.python.ops import collective_ops as _collective_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -70,19 +68,6 @@ device_combination = (
|
|||||||
device='GPU', communication=['RING', 'NCCL'], required_gpus=2))
|
device='GPU', communication=['RING', 'NCCL'], required_gpus=2))
|
||||||
|
|
||||||
|
|
||||||
collective_op_combinations = combinations.times(
|
|
||||||
combinations.combine(
|
|
||||||
collective_op=[
|
|
||||||
combinations.NamedObject('all_reduce', CollectiveOpsV1.all_reduce),
|
|
||||||
combinations.NamedObject('all_reduce_v2',
|
|
||||||
CollectiveOpsV2.all_reduce),
|
|
||||||
combinations.NamedObject('all_gather', CollectiveOpsV1.all_gather),
|
|
||||||
combinations.NamedObject('all_gather_v2',
|
|
||||||
CollectiveOpsV2.all_gather),
|
|
||||||
],
|
|
||||||
mode='eager'), device_combination)
|
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.times(
|
combinations.times(
|
||||||
combinations.combine(
|
combinations.combine(
|
||||||
@ -298,7 +283,20 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
run_and_assert(group_size=3, group_key=2)
|
run_and_assert(group_size=3, group_key=2)
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(collective_op_combinations)
|
@combinations.generate(
|
||||||
|
combinations.times(
|
||||||
|
combinations.combine(
|
||||||
|
collective_op=[
|
||||||
|
combinations.NamedObject('all_reduce',
|
||||||
|
CollectiveOpsV1.all_reduce),
|
||||||
|
combinations.NamedObject('all_reduce_v2',
|
||||||
|
CollectiveOpsV2.all_reduce),
|
||||||
|
combinations.NamedObject('all_gather',
|
||||||
|
CollectiveOpsV1.all_gather),
|
||||||
|
combinations.NamedObject('all_gather_v2',
|
||||||
|
CollectiveOpsV2.all_gather),
|
||||||
|
],
|
||||||
|
mode='eager'), device_combination))
|
||||||
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -551,7 +549,20 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
communication_hint=communication)
|
communication_hint=communication)
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(collective_op_combinations)
|
@combinations.generate(
|
||||||
|
combinations.times(
|
||||||
|
combinations.combine(
|
||||||
|
collective_op=[
|
||||||
|
combinations.NamedObject('all_reduce',
|
||||||
|
CollectiveOpsV1.all_reduce),
|
||||||
|
combinations.NamedObject('all_reduce_v2',
|
||||||
|
CollectiveOpsV2.all_reduce),
|
||||||
|
combinations.NamedObject('all_gather',
|
||||||
|
CollectiveOpsV1.all_gather),
|
||||||
|
combinations.NamedObject('all_gather_v2',
|
||||||
|
CollectiveOpsV2.all_gather),
|
||||||
|
],
|
||||||
|
mode='eager'), device_combination))
|
||||||
class TimeoutTest(test.TestCase, parameterized.TestCase):
|
class TimeoutTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -676,85 +687,6 @@ class TimeoutTest(test.TestCase, parameterized.TestCase):
|
|||||||
communication_hint=communication)
|
communication_hint=communication)
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(collective_op_combinations)
|
|
||||||
class OrderingTest(test.TestCase, parameterized.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
_setup_context()
|
|
||||||
super().setUp()
|
|
||||||
|
|
||||||
def testOrdering(self, collective_op, device, communication):
|
|
||||||
dev0 = '/device:%s:0' % device
|
|
||||||
dev1 = '/device:%s:1' % device
|
|
||||||
group_size = 2
|
|
||||||
group_key = 100
|
|
||||||
instance_key = 100
|
|
||||||
in_tensor = constant_op.constant([1.])
|
|
||||||
|
|
||||||
with ops.device(dev0):
|
|
||||||
token0 = resource_variable_ops.ResourceVariable(0.)
|
|
||||||
with ops.device(dev1):
|
|
||||||
token1 = resource_variable_ops.ResourceVariable(0.)
|
|
||||||
|
|
||||||
@def_function.function
|
|
||||||
def f():
|
|
||||||
# Launch the first collective with token.
|
|
||||||
with ops.device(dev0):
|
|
||||||
collective_op(
|
|
||||||
in_tensor,
|
|
||||||
group_size,
|
|
||||||
group_key,
|
|
||||||
instance_key,
|
|
||||||
ordering_token=token0.handle)
|
|
||||||
with ops.device(dev1):
|
|
||||||
collective_op(
|
|
||||||
in_tensor,
|
|
||||||
group_size,
|
|
||||||
group_key,
|
|
||||||
instance_key,
|
|
||||||
ordering_token=token1.handle)
|
|
||||||
# Launch the second collective without token.
|
|
||||||
with ops.device(dev0):
|
|
||||||
collective_op(in_tensor, group_size, group_key, instance_key)
|
|
||||||
with ops.device(dev1):
|
|
||||||
collective_op(in_tensor, group_size, group_key, instance_key)
|
|
||||||
# Launch the third collective with token.
|
|
||||||
with ops.device(dev0):
|
|
||||||
collective_op(
|
|
||||||
in_tensor,
|
|
||||||
group_size,
|
|
||||||
group_key,
|
|
||||||
instance_key,
|
|
||||||
ordering_token=token0.handle)
|
|
||||||
with ops.device(dev1):
|
|
||||||
collective_op(
|
|
||||||
in_tensor,
|
|
||||||
group_size,
|
|
||||||
group_key,
|
|
||||||
instance_key,
|
|
||||||
ordering_token=token1.handle)
|
|
||||||
|
|
||||||
graph = f.get_concrete_function().graph
|
|
||||||
for device in [dev0, dev1]:
|
|
||||||
# Try to find the third collective, which should have the first collective
|
|
||||||
# as a control input.
|
|
||||||
third = None
|
|
||||||
for op in graph.get_operations():
|
|
||||||
if (op.type.startswith('Collective') and op.device.endswith(device) and
|
|
||||||
op.control_inputs and
|
|
||||||
op.control_inputs[0].type.startswith('Collective')):
|
|
||||||
self.assertIsNone(third)
|
|
||||||
third = op
|
|
||||||
self.assertIsNotNone(third)
|
|
||||||
# Verify it's not the second collective by looking at the inputs.
|
|
||||||
self.assertTrue(any(v.dtype == dtypes.resource for v in third.inputs))
|
|
||||||
first = third.control_inputs[0]
|
|
||||||
self.assertEqual(third.device, first.device)
|
|
||||||
# Verify it's not the second collective by looking at the inputs.
|
|
||||||
self.assertTrue(any(v.dtype == dtypes.resource for v in first.inputs))
|
|
||||||
self.assertEmpty(first.control_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
def _setup_context():
|
def _setup_context():
|
||||||
context._reset_context()
|
context._reset_context()
|
||||||
test_util.set_logical_devices_to_at_least('CPU', 4)
|
test_util.set_logical_devices_to_at_least('CPU', 4)
|
||||||
|
@ -28,8 +28,7 @@ def all_reduce(t,
|
|||||||
final_op='Id',
|
final_op='Id',
|
||||||
subdiv_offsets=(0,),
|
subdiv_offsets=(0,),
|
||||||
communication_hint='auto',
|
communication_hint='auto',
|
||||||
timeout=0,
|
timeout=0):
|
||||||
ordering_token=None):
|
|
||||||
"""Reduces tensors collectively, across devices.
|
"""Reduces tensors collectively, across devices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -51,9 +50,6 @@ def all_reduce(t,
|
|||||||
timeout: a float. If set to a non zero, set a completion timeout to detect
|
timeout: a float. If set to a non zero, set a completion timeout to detect
|
||||||
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
||||||
timeout value in seconds. This feature is experimental.
|
timeout value in seconds. This feature is experimental.
|
||||||
ordering_token: an optional resource tensor to pass to the op as inputs.
|
|
||||||
They aren't used by the kernel but allow AutoControlDependency to order
|
|
||||||
the collectives with control dependencies.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An Op implementing the distributed reduction.
|
An Op implementing the distributed reduction.
|
||||||
@ -63,8 +59,6 @@ def all_reduce(t,
|
|||||||
"""
|
"""
|
||||||
if group_size < 1:
|
if group_size < 1:
|
||||||
raise ValueError('Parameter group_size to all_reduce must be at least 1.')
|
raise ValueError('Parameter group_size to all_reduce must be at least 1.')
|
||||||
if ordering_token is not None:
|
|
||||||
ordering_token = [ordering_token]
|
|
||||||
return gen_collective_ops.collective_reduce(
|
return gen_collective_ops.collective_reduce(
|
||||||
t,
|
t,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
@ -74,8 +68,7 @@ def all_reduce(t,
|
|||||||
final_op=final_op,
|
final_op=final_op,
|
||||||
subdiv_offsets=subdiv_offsets,
|
subdiv_offsets=subdiv_offsets,
|
||||||
communication_hint=communication_hint.lower(),
|
communication_hint=communication_hint.lower(),
|
||||||
timeout_seconds=timeout,
|
timeout_seconds=timeout)
|
||||||
ordering_token=ordering_token or [])
|
|
||||||
|
|
||||||
|
|
||||||
def all_reduce_v2(t,
|
def all_reduce_v2(t,
|
||||||
@ -85,8 +78,7 @@ def all_reduce_v2(t,
|
|||||||
merge_op='Add',
|
merge_op='Add',
|
||||||
final_op='Id',
|
final_op='Id',
|
||||||
communication_hint='auto',
|
communication_hint='auto',
|
||||||
timeout=0,
|
timeout=0):
|
||||||
ordering_token=None):
|
|
||||||
"""Reduces tensors collectively, across devices.
|
"""Reduces tensors collectively, across devices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -106,15 +98,10 @@ def all_reduce_v2(t,
|
|||||||
timeout: a float. If set to a non zero, set a completion timeout to detect
|
timeout: a float. If set to a non zero, set a completion timeout to detect
|
||||||
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
||||||
timeout value in seconds. This feature is experimental.
|
timeout value in seconds. This feature is experimental.
|
||||||
ordering_token: an optional resource tensor to pass to the op as inputs.
|
|
||||||
They aren't used by the kernel but allow AutoControlDependency to order
|
|
||||||
the collectives with control dependencies.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An Op implementing the distributed reduction.
|
An Op implementing the distributed reduction.
|
||||||
"""
|
"""
|
||||||
if ordering_token is not None:
|
|
||||||
ordering_token = [ordering_token]
|
|
||||||
return gen_collective_ops.collective_reduce_v2(
|
return gen_collective_ops.collective_reduce_v2(
|
||||||
t,
|
t,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
@ -123,8 +110,7 @@ def all_reduce_v2(t,
|
|||||||
merge_op=merge_op,
|
merge_op=merge_op,
|
||||||
final_op=final_op,
|
final_op=final_op,
|
||||||
communication_hint=communication_hint.lower(),
|
communication_hint=communication_hint.lower(),
|
||||||
timeout_seconds=timeout,
|
timeout_seconds=timeout)
|
||||||
ordering_token=ordering_token or [])
|
|
||||||
|
|
||||||
|
|
||||||
def all_gather(t,
|
def all_gather(t,
|
||||||
@ -132,8 +118,7 @@ def all_gather(t,
|
|||||||
group_key,
|
group_key,
|
||||||
instance_key,
|
instance_key,
|
||||||
communication_hint='auto',
|
communication_hint='auto',
|
||||||
timeout=0,
|
timeout=0):
|
||||||
ordering_token=None):
|
|
||||||
"""Accumulates tensors collectively, across devices, along first dimension.
|
"""Accumulates tensors collectively, across devices, along first dimension.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -148,9 +133,6 @@ def all_gather(t,
|
|||||||
timeout: a float. If set to a non zero, set a completion timeout to detect
|
timeout: a float. If set to a non zero, set a completion timeout to detect
|
||||||
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
||||||
timeout value in seconds. This feature is experimental.
|
timeout value in seconds. This feature is experimental.
|
||||||
ordering_token: an optional resource tensor to pass to the op as inputs.
|
|
||||||
They aren't used by the kernel but allow AutoControlDependency to order
|
|
||||||
the collectives with control dependencies.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An Op implementing the distributed operation.
|
An Op implementing the distributed operation.
|
||||||
@ -160,8 +142,6 @@ def all_gather(t,
|
|||||||
"""
|
"""
|
||||||
if group_size < 1:
|
if group_size < 1:
|
||||||
raise ValueError('Parameter group_size to all_gather must be at least 1.')
|
raise ValueError('Parameter group_size to all_gather must be at least 1.')
|
||||||
if ordering_token is not None:
|
|
||||||
ordering_token = [ordering_token]
|
|
||||||
return gen_collective_ops.collective_gather(
|
return gen_collective_ops.collective_gather(
|
||||||
t,
|
t,
|
||||||
shape=[0],
|
shape=[0],
|
||||||
@ -169,8 +149,7 @@ def all_gather(t,
|
|||||||
group_key=group_key,
|
group_key=group_key,
|
||||||
instance_key=instance_key,
|
instance_key=instance_key,
|
||||||
communication_hint=communication_hint.lower(),
|
communication_hint=communication_hint.lower(),
|
||||||
timeout_seconds=timeout,
|
timeout_seconds=timeout)
|
||||||
ordering_token=ordering_token or [])
|
|
||||||
|
|
||||||
|
|
||||||
def all_gather_v2(t,
|
def all_gather_v2(t,
|
||||||
@ -178,8 +157,7 @@ def all_gather_v2(t,
|
|||||||
group_key,
|
group_key,
|
||||||
instance_key,
|
instance_key,
|
||||||
communication_hint='auto',
|
communication_hint='auto',
|
||||||
timeout=0,
|
timeout=0):
|
||||||
ordering_token=None):
|
|
||||||
"""Accumulates tensors collectively, across devices, along first dimension.
|
"""Accumulates tensors collectively, across devices, along first dimension.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -195,23 +173,17 @@ def all_gather_v2(t,
|
|||||||
timeout: a float. If set to a non zero, set a completion timeout to detect
|
timeout: a float. If set to a non zero, set a completion timeout to detect
|
||||||
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
staleness. If the timer goes off, a DeadlineExceededError is raised. The
|
||||||
timeout value in seconds. This feature is experimental.
|
timeout value in seconds. This feature is experimental.
|
||||||
ordering_token: an optional resource tensor to pass to the op as inputs.
|
|
||||||
They aren't used by the kernel but allow AutoControlDependency to order
|
|
||||||
the collectives with control dependencies.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An Op implementing the distributed operation.
|
An Op implementing the distributed operation.
|
||||||
"""
|
"""
|
||||||
if ordering_token is not None:
|
|
||||||
ordering_token = [ordering_token]
|
|
||||||
return gen_collective_ops.collective_gather_v2(
|
return gen_collective_ops.collective_gather_v2(
|
||||||
t,
|
t,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
group_key=group_key,
|
group_key=group_key,
|
||||||
instance_key=instance_key,
|
instance_key=instance_key,
|
||||||
communication_hint=communication_hint.lower(),
|
communication_hint=communication_hint.lower(),
|
||||||
timeout_seconds=timeout,
|
timeout_seconds=timeout)
|
||||||
ordering_token=ordering_token or [])
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_send(t,
|
def broadcast_send(t,
|
||||||
|
@ -758,11 +758,11 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectiveGather"
|
name: "CollectiveGather"
|
||||||
argspec: "args=[\'input\', \'ordering_token\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectiveGatherV2"
|
name: "CollectiveGatherV2"
|
||||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectivePermute"
|
name: "CollectivePermute"
|
||||||
@ -770,11 +770,11 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectiveReduce"
|
name: "CollectiveReduce"
|
||||||
argspec: "args=[\'input\', \'ordering_token\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
|
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectiveReduceV2"
|
name: "CollectiveReduceV2"
|
||||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CombinedNonMaxSuppression"
|
name: "CombinedNonMaxSuppression"
|
||||||
|
@ -758,11 +758,11 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectiveGather"
|
name: "CollectiveGather"
|
||||||
argspec: "args=[\'input\', \'ordering_token\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectiveGatherV2"
|
name: "CollectiveGatherV2"
|
||||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectivePermute"
|
name: "CollectivePermute"
|
||||||
@ -770,11 +770,11 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectiveReduce"
|
name: "CollectiveReduce"
|
||||||
argspec: "args=[\'input\', \'ordering_token\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
|
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'subdiv_offsets\', \'wait_for\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'auto\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CollectiveReduceV2"
|
name: "CollectiveReduceV2"
|
||||||
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'ordering_token\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CombinedNonMaxSuppression"
|
name: "CombinedNonMaxSuppression"
|
||||||
|
Loading…
Reference in New Issue
Block a user