Exported Send/Recv in tf.raw_ops

This is done to streamline Channel implementation in Lingvo which currently
relies on a mutable op_def_registry.

PiperOrigin-RevId: 270841220
This commit is contained in:
Sergei Lebedev 2019-09-23 23:44:33 -07:00 committed by TensorFlower Gardener
parent 4aa7dbce08
commit a7d82bc543
8 changed files with 148 additions and 11 deletions

View File

@ -0,0 +1,44 @@
op {
graph_op_name: "Recv"
visibility: HIDDEN
out_arg {
name: "tensor"
description: <<END
The tensor to receive.
END
}
attr {
name: "tensor_name"
description: <<END
The name of the tensor to receive.
END
}
attr {
name: "send_device"
description: <<END
The name of the device sending the tensor.
END
}
attr {
name: "send_device_incarnation"
description: <<END
The current incarnation of send_device.
END
}
attr {
name: "recv_device"
description: <<END
The name of the device receiving the tensor.
END
}
attr {
name: "client_terminated"
description: <<END
If set to true, this indicates that the node was added
to the graph as a result of a client-side feed or fetch of Tensor data,
in which case the corresponding send or recv is expected to be managed
locally by the caller.
END
}
summary: "Receives the named tensor from send_device on recv_device."
}

View File

@ -0,0 +1,44 @@
op {
graph_op_name: "Send"
visibility: HIDDEN
in_arg {
name: "tensor"
description: <<END
The tensor to send.
END
}
attr {
name: "tensor_name"
description: <<END
The name of the tensor to send.
END
}
attr {
name: "send_device"
description: <<END
The name of the device sending the tensor.
END
}
attr {
name: "send_device_incarnation"
description: <<END
The current incarnation of send_device.
END
}
attr {
name: "recv_device"
description: <<END
The name of the device receiving the tensor.
END
}
attr {
name: "client_terminated"
description: <<END
If set to true, this indicates that the node was added
to the graph as a result of a client-side feed or fetch of Tensor data,
in which case the corresponding send or recv is expected to be managed
locally by the caller.
END
}
summary: "Sends the named tensor from send_device to recv_device."
}

View File

@ -110,6 +110,10 @@ void SendOp::Compute(OpKernelContext* ctx) {
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_DEFAULT), SendOp);
// Public alias. Added for use in Lingvo.
REGISTER_KERNEL_BUILDER(Name("Send").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("Send").Device(DEVICE_DEFAULT), SendOp);
REGISTER_KERNEL_BUILDER(
Name("_HostSend").Device(DEVICE_DEFAULT).HostMemory("tensor"), SendOp);
@ -191,6 +195,10 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_DEFAULT), RecvOp);
// Public alias. Added for use in Lingvo.
REGISTER_KERNEL_BUILDER(Name("Recv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("Recv").Device(DEVICE_DEFAULT), RecvOp);
REGISTER_KERNEL_BUILDER(
Name("_HostRecv").Device(DEVICE_DEFAULT).HostMemory("tensor"), RecvOp);

View File

@ -42,6 +42,17 @@ client_terminated: If set to true, this indicates that the node was added
locally by the caller.
)doc");
REGISTER_OP("Send")
.Input("tensor: T")
.Attr("T: type")
.Attr("tensor_name: string")
.Attr("send_device: string")
.Attr("send_device_incarnation: int")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("_Recv")
.Output("tensor: tensor_type")
.Attr("tensor_type: type")
@ -66,6 +77,17 @@ client_terminated: If set to true, this indicates that the node was added
locally by the caller.
)doc");
REGISTER_OP("Recv")
.Output("tensor: tensor_type")
.Attr("tensor_type: type")
.Attr("tensor_name: string")
.Attr("send_device: string")
.Attr("send_device_incarnation: int")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("_HostSend")
.Input("tensor: T")
.Attr("T: type")

View File

@ -145,6 +145,7 @@ py_library(
":rnn_ops_gen",
":saver_test_utils",
":script_ops",
":sendrecv_ops_gen",
":session_ops",
":sets",
":sparse_ops",
@ -2348,6 +2349,13 @@ tf_gen_op_wrapper_private_py(
visibility = ["//tensorflow/contrib/rnn:__pkg__"],
)
tf_gen_op_wrapper_private_py(
name = "sendrecv_ops_gen",
deps = [
"//tensorflow/core:sendrecv_ops_op_lib",
],
)
tf_gen_op_wrapper_private_py(
name = "tpu_ops_gen",
visibility = [

View File

@ -108,17 +108,12 @@ from tensorflow.python.tpu import api
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
# Import audio ops to make sure the ops are registered.
from tensorflow.python.ops import gen_audio_ops as _
# Import boosted trees ops to make sure the ops are registered (but unused).
from tensorflow.python.ops import gen_boosted_trees_ops as _gen_boosted_trees_ops
# Import cudnn rnn ops to make sure their ops are registered.
from tensorflow.python.ops import gen_cudnn_rnn_ops as _
# Import rnn_ops to make sure their ops are registered.
from tensorflow.python.ops import gen_rnn_ops as _
# Import to make sure the ops are registered.
from tensorflow.python.ops import gen_audio_ops
from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import gen_rnn_ops
from tensorflow.python.ops import gen_sendrecv_ops
# Import the names from python/training.py as train.Name.
from tensorflow.python.training import training as train

View File

@ -3044,6 +3044,10 @@ tf_module {
name: "RecordInput"
argspec: "args=[\'file_pattern\', \'file_random_seed\', \'file_shuffle_shift_ratio\', \'file_buffer_size\', \'file_parallelism\', \'batch_size\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'301\', \'0\', \'10000\', \'16\', \'32\', \'\', \'None\'], "
}
member_method {
name: "Recv"
argspec: "args=[\'tensor_type\', \'tensor_name\', \'send_device\', \'send_device_incarnation\', \'recv_device\', \'client_terminated\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "RecvTPUEmbeddingActivations"
argspec: "args=[\'num_outputs\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -3628,6 +3632,10 @@ tf_module {
name: "SeluGrad"
argspec: "args=[\'gradients\', \'outputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Send"
argspec: "args=[\'tensor\', \'tensor_name\', \'send_device\', \'send_device_incarnation\', \'recv_device\', \'client_terminated\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SendTPUEmbeddingGradients"
argspec: "args=[\'inputs\', \'learning_rates\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -3044,6 +3044,10 @@ tf_module {
name: "RecordInput"
argspec: "args=[\'file_pattern\', \'file_random_seed\', \'file_shuffle_shift_ratio\', \'file_buffer_size\', \'file_parallelism\', \'batch_size\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'301\', \'0\', \'10000\', \'16\', \'32\', \'\', \'None\'], "
}
member_method {
name: "Recv"
argspec: "args=[\'tensor_type\', \'tensor_name\', \'send_device\', \'send_device_incarnation\', \'recv_device\', \'client_terminated\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "RecvTPUEmbeddingActivations"
argspec: "args=[\'num_outputs\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -3628,6 +3632,10 @@ tf_module {
name: "SeluGrad"
argspec: "args=[\'gradients\', \'outputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Send"
argspec: "args=[\'tensor\', \'tensor_name\', \'send_device\', \'send_device_incarnation\', \'recv_device\', \'client_terminated\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SendTPUEmbeddingGradients"
argspec: "args=[\'inputs\', \'learning_rates\', \'config\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "