Add optional _shapes/_handle_shapes/_handle_dtypes attributes to _Arg node.
PiperOrigin-RevId: 239512052
This commit is contained in:
parent
b22875dac9
commit
41b7433161
@ -1148,6 +1148,10 @@ tf_gen_op_libs(
|
||||
"summary_ops",
|
||||
"training_ops",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
|
@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -25,7 +28,54 @@ REGISTER_SYSTEM_OP("_Arg")
|
||||
.Attr("index: int >= 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* context) {
|
||||
const AttrValue* dtype_attr = context->attrs().Find("T");
|
||||
if (!dtype_attr) {
|
||||
return errors::InvalidArgument(
|
||||
"_Arg node does not have attribute \"T\"");
|
||||
}
|
||||
|
||||
if (dtype_attr->type() == DT_RESOURCE) {
|
||||
const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes");
|
||||
const AttrValue* shape_attr = context->attrs().Find("_handle_shapes");
|
||||
if (dtype_attr && shape_attr) {
|
||||
if (dtype_attr->list().type().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
if (shape_attr->list().shape().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
DataType dtype = dtype_attr->list().type(0);
|
||||
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||
context->set_output_handle_shapes_and_types(
|
||||
0, std::vector<shape_inference::ShapeAndType>{
|
||||
{shape_handle, dtype}});
|
||||
} else {
|
||||
context->set_output(0, context->UnknownShape());
|
||||
}
|
||||
} else {
|
||||
const AttrValue* shape_attr = context->attrs().Find("_shapes");
|
||||
if (shape_attr && shape_attr->has_shape()) {
|
||||
if (shape_attr->list().shape().empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid \"_shapes\" attribute value for _Arg node: ",
|
||||
shape_attr->DebugString());
|
||||
}
|
||||
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||
context->set_output(0, shape_handle);
|
||||
} else {
|
||||
context->set_output(0, context->UnknownShape());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
@ -33,6 +83,18 @@ A graph node which represents an argument to a function.
|
||||
|
||||
output: The argument.
|
||||
index: This argument is the index-th argument of the function.
|
||||
|
||||
Attributes for shape inference:
|
||||
1. _shapes: this attribute can be set on an _Arg node producing non-resource
|
||||
output(s). If set, its value should contain a list of TensorShapeProto
|
||||
describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
|
||||
_Arg node's shape inference function will use it as the node's output shapes.
|
||||
2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
|
||||
node producing resource output(s). If set, value of _handle_dtypes should
|
||||
contain the dtype(s) of the resource(s) and value of _handle_shapes should
|
||||
contain the shape(s) of the resource(s). If both attributes are set, _Arg
|
||||
node's shape inference function will use their values as the node's output
|
||||
type(s) and shape(s).
|
||||
)doc");
|
||||
|
||||
REGISTER_SYSTEM_OP("_DeviceArg")
|
||||
|
Loading…
Reference in New Issue
Block a user