Add optional _shapes/_handle_shapes/_handle_dtypes attributes to _Arg node.

PiperOrigin-RevId: 239512052
This commit is contained in:
Tong Shen 2019-03-20 17:43:52 -07:00 committed by TensorFlower Gardener
parent b22875dac9
commit 41b7433161
2 changed files with 67 additions and 1 deletions

View File

@ -1148,6 +1148,10 @@ tf_gen_op_libs(
"summary_ops",
"training_ops",
],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
tf_gen_op_libs(

View File

@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@ -25,7 +28,54 @@ REGISTER_SYSTEM_OP("_Arg")
.Attr("index: int >= 0")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* context) {
context->set_output(0, context->UnknownShape());
const AttrValue* dtype_attr = context->attrs().Find("T");
if (!dtype_attr) {
return errors::InvalidArgument(
"_Arg node does not have attribute \"T\"");
}
if (dtype_attr->type() == DT_RESOURCE) {
const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes");
const AttrValue* shape_attr = context->attrs().Find("_handle_shapes");
if (dtype_attr && shape_attr) {
if (dtype_attr->list().type().empty()) {
return errors::InvalidArgument(
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
if (shape_attr->list().shape().empty()) {
return errors::InvalidArgument(
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
DataType dtype = dtype_attr->list().type(0);
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
shape_inference::ShapeHandle shape_handle;
TF_RETURN_IF_ERROR(
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
context->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{
{shape_handle, dtype}});
} else {
context->set_output(0, context->UnknownShape());
}
} else {
const AttrValue* shape_attr = context->attrs().Find("_shapes");
if (shape_attr && shape_attr->has_shape()) {
if (shape_attr->list().shape().empty()) {
return errors::InvalidArgument(
"Invalid \"_shapes\" attribute value for _Arg node: ",
shape_attr->DebugString());
}
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
shape_inference::ShapeHandle shape_handle;
TF_RETURN_IF_ERROR(
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
context->set_output(0, shape_handle);
} else {
context->set_output(0, context->UnknownShape());
}
}
return Status::OK();
})
.Doc(R"doc(
@ -33,6 +83,18 @@ A graph node which represents an argument to a function.
output: The argument.
index: This argument is the index-th argument of the function.
Attributes for shape inference:
1. _shapes: this attribute can be set on an _Arg node producing non-resource
output(s). If set, its value should contain a list of TensorShapeProto
describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
_Arg node's shape inference function will use it as the node's output shapes.
2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
node producing resource output(s). If set, value of _handle_dtypes should
contain the dtype(s) of the resource(s) and value of _handle_shapes should
contain the shape(s) of the resource(s). If both attributes are set, _Arg
node's shape inference function will use their values as the node's output
type(s) and shape(s).
)doc");
REGISTER_SYSTEM_OP("_DeviceArg")