Add optional _shapes/_handle_shapes/_handle_dtypes attributes to _Arg node.
PiperOrigin-RevId: 239512052
This commit is contained in:
parent
b22875dac9
commit
41b7433161
@ -1148,6 +1148,10 @@ tf_gen_op_libs(
|
|||||||
"summary_ops",
|
"summary_ops",
|
||||||
"training_ops",
|
"training_ops",
|
||||||
],
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_gen_op_libs(
|
tf_gen_op_libs(
|
||||||
|
@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -25,7 +28,54 @@ REGISTER_SYSTEM_OP("_Arg")
|
|||||||
.Attr("index: int >= 0")
|
.Attr("index: int >= 0")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* context) {
|
.SetShapeFn([](shape_inference::InferenceContext* context) {
|
||||||
context->set_output(0, context->UnknownShape());
|
const AttrValue* dtype_attr = context->attrs().Find("T");
|
||||||
|
if (!dtype_attr) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"_Arg node does not have attribute \"T\"");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dtype_attr->type() == DT_RESOURCE) {
|
||||||
|
const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes");
|
||||||
|
const AttrValue* shape_attr = context->attrs().Find("_handle_shapes");
|
||||||
|
if (dtype_attr && shape_attr) {
|
||||||
|
if (dtype_attr->list().type().empty()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
|
||||||
|
shape_attr->DebugString());
|
||||||
|
}
|
||||||
|
if (shape_attr->list().shape().empty()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid \"_handle_shapes\" attribute value for _Arg node: ",
|
||||||
|
shape_attr->DebugString());
|
||||||
|
}
|
||||||
|
DataType dtype = dtype_attr->list().type(0);
|
||||||
|
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||||
|
shape_inference::ShapeHandle shape_handle;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||||
|
context->set_output_handle_shapes_and_types(
|
||||||
|
0, std::vector<shape_inference::ShapeAndType>{
|
||||||
|
{shape_handle, dtype}});
|
||||||
|
} else {
|
||||||
|
context->set_output(0, context->UnknownShape());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const AttrValue* shape_attr = context->attrs().Find("_shapes");
|
||||||
|
if (shape_attr && shape_attr->has_shape()) {
|
||||||
|
if (shape_attr->list().shape().empty()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid \"_shapes\" attribute value for _Arg node: ",
|
||||||
|
shape_attr->DebugString());
|
||||||
|
}
|
||||||
|
const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
|
||||||
|
shape_inference::ShapeHandle shape_handle;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
|
||||||
|
context->set_output(0, shape_handle);
|
||||||
|
} else {
|
||||||
|
context->set_output(0, context->UnknownShape());
|
||||||
|
}
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
@ -33,6 +83,18 @@ A graph node which represents an argument to a function.
|
|||||||
|
|
||||||
output: The argument.
|
output: The argument.
|
||||||
index: This argument is the index-th argument of the function.
|
index: This argument is the index-th argument of the function.
|
||||||
|
|
||||||
|
Attributes for shape inference:
|
||||||
|
1. _shapes: this attribute can be set on an _Arg node producing non-resource
|
||||||
|
output(s). If set, its value should contain a list of TensorShapeProto
|
||||||
|
describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
|
||||||
|
_Arg node's shape inference function will use it as the node's output shapes.
|
||||||
|
2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
|
||||||
|
node producing resource output(s). If set, value of _handle_dtypes should
|
||||||
|
contain the dtype(s) of the resource(s) and value of _handle_shapes should
|
||||||
|
contain the shape(s) of the resource(s). If both attributes are set, _Arg
|
||||||
|
node's shape inference function will use their values as the node's output
|
||||||
|
type(s) and shape(s).
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
REGISTER_SYSTEM_OP("_DeviceArg")
|
REGISTER_SYSTEM_OP("_DeviceArg")
|
||||||
|
Loading…
Reference in New Issue
Block a user