From 41b74331612aa264565d97500ab530206254e226 Mon Sep 17 00:00:00 2001 From: Tong Shen Date: Wed, 20 Mar 2019 17:43:52 -0700 Subject: [PATCH] Add optional _shapes/_handle_shapes/_handle_dtypes attributes to _Arg node. PiperOrigin-RevId: 239512052 --- tensorflow/core/BUILD | 4 ++ tensorflow/core/ops/function_ops.cc | 64 ++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 6445ecd975d..013d7d3b2b3 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1148,6 +1148,10 @@ tf_gen_op_libs( "summary_ops", "training_ops", ], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], ) tf_gen_op_libs( diff --git a/tensorflow/core/ops/function_ops.cc b/tensorflow/core/ops/function_ops.cc index 8e86dd9f780..4c99815a05b 100644 --- a/tensorflow/core/ops/function_ops.cc +++ b/tensorflow/core/ops/function_ops.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -25,7 +28,54 @@ REGISTER_SYSTEM_OP("_Arg") .Attr("index: int >= 0") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* context) { - context->set_output(0, context->UnknownShape()); + const AttrValue* dtype_attr = context->attrs().Find("T"); + if (!dtype_attr) { + return errors::InvalidArgument( + "_Arg node does not have attribute \"T\""); + } + + if (dtype_attr->type() == DT_RESOURCE) { + const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes"); + const AttrValue* shape_attr = context->attrs().Find("_handle_shapes"); + if (dtype_attr && shape_attr) { + if (dtype_attr->list().type().empty()) { + return errors::InvalidArgument( + "Invalid \"_handle_dtypes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + if (shape_attr->list().shape().empty()) { + return errors::InvalidArgument( + "Invalid \"_handle_shapes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + DataType dtype = dtype_attr->list().type(0); + const TensorShapeProto& shape_proto = shape_attr->list().shape(0); + shape_inference::ShapeHandle shape_handle; + TF_RETURN_IF_ERROR( + context->MakeShapeFromShapeProto(shape_proto, &shape_handle)); + context->set_output_handle_shapes_and_types( + 0, std::vector{ + {shape_handle, dtype}}); + } else { + context->set_output(0, context->UnknownShape()); + } + } else { + const AttrValue* shape_attr = context->attrs().Find("_shapes"); + if (shape_attr && shape_attr->has_shape()) { + if (shape_attr->list().shape().empty()) { + return errors::InvalidArgument( + "Invalid \"_shapes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + const TensorShapeProto& shape_proto = shape_attr->list().shape(0); + shape_inference::ShapeHandle shape_handle; + TF_RETURN_IF_ERROR( + context->MakeShapeFromShapeProto(shape_proto, &shape_handle)); + context->set_output(0, shape_handle); + } else { + context->set_output(0, context->UnknownShape()); + } + } return Status::OK(); }) .Doc(R"doc( @@ -33,6 +83,18 @@ A graph node which represents an argument to a function. output: The argument. index: This argument is the index-th argument of the function. + +Attributes for shape inference: +1. _shapes: this attribute can be set on an _Arg node producing non-resource + output(s). If set, its value should contain a list of TensorShapeProto + describing the shape(s) of the tensor(s) this _Arg node will produce. If set, + _Arg node's shape inference function will use it as the node's output shapes. +2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg + node producing resource output(s). If set, value of _handle_dtypes should + contain the dtype(s) of the resource(s) and value of _handle_shapes should + contain the shape(s) of the resource(s). If both attributes are set, _Arg + node's shape inference function will use their values as the node's output + type(s) and shape(s). )doc"); REGISTER_SYSTEM_OP("_DeviceArg")