_Arg node requires _handle_dtypes and _handle_shapes attrs when T is DT_RESOURCE.

Otherwise, a function node with a DT_RESOURCE-typed input fails to infer its output shapes.

PiperOrigin-RevId: 308927584
Change-Id: I653b293aaa36ffd4a6bed68e4692398eebd9d578

commit 7028dbefc0
parent 62856c9366
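Why the missing attrs break inference: shape functions for ops that read a resource (ReadVariableOp and similar) look up the shape/dtype metadata carried alongside the resource handle. The sketch below is illustrative only and is not the TensorFlow implementation of any particular shape function; the name ReadVariableLikeShapeFn is hypothetical. It shows that, without handle metadata on the _Arg that produces the resource, the consumer can only report an unknown output shape.

#include "tensorflow/core/framework/shape_inference.h"

// Hypothetical shape function (sketch) for an op whose output is the value
// stored behind a resource handle arriving on input 0.
tensorflow::Status ReadVariableLikeShapeFn(
    tensorflow::shape_inference::InferenceContext* c) {
  // Shape/dtype metadata propagated with the resource handle; this is what
  // the _handle_shapes/_handle_dtypes attrs on the _Arg node ultimately feed.
  const auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data != nullptr && !handle_data->empty()) {
    // The variable's value shape travels with the handle.
    c->set_output(0, (*handle_data)[0].shape);
  } else {
    // No handle metadata: nothing better than a fully unknown shape.
    c->set_output(0, c->UnknownShape());
  }
  return tensorflow::Status::OK();
}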
@@ -804,14 +804,47 @@ class SymbolicShapeRefiner {
       int output_port_num = input_tensor.index();
       AttrValue attr_output_shape;
       TensorShapeProto proto;
-      const auto& handle = input_ic->output(output_port_num);
+      const auto handle = input_ic->output(output_port_num);
       input_ic->ShapeHandleToProto(handle, &proto);
       // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
       NormalizeShapeForOutput(&proto);
+      // _Arg op's output shape uses _output_shapes attr.
       AttrValue output_attr;
       output_attr.mutable_list()->add_shape()->Swap(&proto);
       (*fun_node->mutable_attr())["_output_shapes"] = output_attr;
+
+      // If dtype is DT_RESOURCE, ops that read _Arg op use _handle_dtypes and
+      // _handle_shapes attr for its shapes and dtypes.
+      if (fun_input.data_type == DT_RESOURCE) {
+        auto* shapes_and_types =
+            input_ic->output_handle_shapes_and_types(output_port_num);
+        if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
+          AttrValue dtype_attr;
+          AttrValue shape_attr;
+          for (const auto& shape_and_type : *shapes_and_types) {
+            const auto& dtype = shape_and_type.dtype;
+            const auto& shape_handle = shape_and_type.shape;
+            dtype_attr.mutable_list()->add_type(dtype);
+            input_ic->ShapeHandleToProto(
+                shape_handle, shape_attr.mutable_list()->add_shape());
+          }
+          (*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
+          (*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
+        } else {
+          // Note that we do not return error here, even if the input node does
+          // not have shapes_and_types. Within the function, we cannot infer the
+          // output shape of the DT_RESOURCE input; hence, potentially unknown
+          // shapes/dims in the function output shapes.
+          VLOG(2)
+              << "A function node (" << function_node->name()
+              << ") has input with DT_RESOURCE, but the input node does not "
+              << "have shapes_and_types information: \n"
+              << "function_node: " << function_node->ShortDebugString() << "\n"
+              << "function input: " << i
+              << ", input node's output: " << output_port_num << "\n"
+              << "input node: " << input_node->ShortDebugString();
+        }
+      }
     }

     // Replace input nodes with Consts, if values are known. Note that
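For context, here is the other side of the mechanism the hunk above implements: turning the _handle_dtypes/_handle_shapes attrs written onto the _Arg NodeDef back into ShapeAndType entries inside the function body. This is a hedged sketch; the helper name HandleShapesFromAttrs and its structure are illustrative, not taken from the TensorFlow sources. Only the attr names and the InferenceContext calls mirror the change.

#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

using tensorflow::NodeDef;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeAndType;
using tensorflow::shape_inference::ShapeHandle;

// Hypothetical helper (sketch): recover ShapeAndType entries from the
// _handle_dtypes/_handle_shapes attrs of an _Arg node. Leaves `out` empty
// when the attrs are absent.
tensorflow::Status HandleShapesFromAttrs(const NodeDef& arg_node,
                                         InferenceContext* ic,
                                         std::vector<ShapeAndType>* out) {
  const auto& attrs = arg_node.attr();
  const auto dtypes_it = attrs.find("_handle_dtypes");
  const auto shapes_it = attrs.find("_handle_shapes");
  if (dtypes_it == attrs.end() || shapes_it == attrs.end()) {
    return tensorflow::Status::OK();
  }
  const auto& dtypes = dtypes_it->second.list();
  const auto& shapes = shapes_it->second.list();
  for (int i = 0; i < shapes.shape_size() && i < dtypes.type_size(); ++i) {
    ShapeHandle handle;
    // Convert the stored TensorShapeProto back into a ShapeHandle.
    TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shapes.shape(i), &handle));
    out->emplace_back(handle, dtypes.type(i));
  }
  return tensorflow::Status::OK();
}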
@@ -820,7 +853,7 @@ class SymbolicShapeRefiner {
     auto* ic = ctx->inference_context.get();
     for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
       const string& input = function_node->input(i);
-      const string& node_name = NodeName(input);
+      const string node_name = NodeName(input);
       const NodeDef* input_node = graph_.GetNode(node_name);
       if (IsConstant(*input_node)) {
         TF_CHECK_OK(
@@ -993,9 +1026,11 @@ class SymbolicShapeRefiner {
     // Convert all kUnknownDimFromConst to -1 for shape inference.
     ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
         ic, ctx->input_tensors_as_shapes_to_propagate));
-    // Notice: UpdateFunction only uses input_tensors_as_shapes, so for function
-    // nodes, we dont' perform the conversion from TensorProtos to Tensors for
-    // constant inputs here.
+    // Note: UpdateFunction uses input_tensors_as_shapes and
+    // input_tensor_protos (not the Tensor object) for input values.
+    // so for function nodes, we don't need to convert TensorProtos
+    // to Tensors here. If the current op is not a function op, we convert
+    // TensorProtos to Tensors before calling InferShapes.

     // Properly handle function nodes.
     if (ctx->op_data && ctx->op_data->is_function_op) {
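The updated comment distinguishes propagating constant inputs as shapes/protos from materializing them as Tensor objects. As a minimal, hedged illustration of the conversion it refers to (the helper below is hypothetical and only demonstrates the Tensor::FromProto call used when a non-function op needs the actual Tensor before InferShapes):

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"

// Hypothetical helper (sketch): materialize a constant input stored as a
// TensorProto into a Tensor. Returns false if the proto is malformed.
bool MaterializeConstInput(const tensorflow::TensorProto& proto,
                           tensorflow::Tensor* out) {
  return out->FromProto(proto);
}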
@@ -0,0 +1,130 @@
+node {
+  name: "x"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 1
+          }
+          dim {
+            size: 3
+          }
+        }
+        float_val: 2.0
+      }
+    }
+  }
+}
+node {
+  name: "y"
+  op: "_Arg"
+  attr {
+    key: "T"
+    value {
+      type: DT_RESOURCE
+    }
+  }
+  attr {
+    key: "_handle_dtypes"
+    value {
+      list {
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    key: "_handle_shapes"
+    value {
+      list {
+        shape {
+          dim {
+            size: 1
+          }
+          dim {
+            size: 3
+          }
+        }
+      }
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 0
+    }
+  }
+}
+node {
+  name: "FunctionWithDtResourceInput"
+  op: "FunctionWithDtResourceInput"
+  input: "x"
+  input: "y"
+}
+library {
+  function {
+    signature {
+      name: "FunctionWithDtResourceInput"
+      input_arg {
+        name: "x"
+        type: DT_FLOAT
+      }
+      input_arg {
+        name: "y"
+        type: DT_RESOURCE
+      }
+      output_arg {
+        name: "z1"
+        type: DT_FLOAT
+      }
+      output_arg {
+        name: "z2"
+        type: DT_FLOAT
+      }
+    }
+    node_def {
+      name: "y1"
+      op: "ReadVariableOp"
+      input: "y"
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node_def {
+      name: "Add"
+      op: "Add"
+      input: "x"
+      input: "y1:value:0"
+      attr {
+        key: "T"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    ret {
+      key: "z1"
+      value: "Add:z:0"
+    }
+    ret {
+      key: "z2"
+      value: "x"
+    }
+  }
+}
+versions {
+  producer: 26
+  min_consumer: 12
+}
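A hedged sketch of how this testdata graph might be exercised: parse the GraphDef, run static shape inference with GraphProperties, and read back the inferred output shapes of the FunctionWithDtResourceInput node. Given the _handle_shapes attr above, both z1 (Add of x and the read variable) and z2 (x passed through) should come out as [1,3] rather than unknown. The file path and the standalone main() are assumptions for illustration; the actual test wiring is not shown in this diff.

#include <iostream>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/platform/env.h"

int main() {
  tensorflow::GraphDef graph_def;
  // Assumed path; the testdata file name is not shown in this diff.
  TF_CHECK_OK(tensorflow::ReadTextProto(
      tensorflow::Env::Default(), "/tmp/function_with_dt_resource_input.pbtxt",
      &graph_def));

  tensorflow::grappler::GrapplerItem item;
  item.graph = graph_def;
  item.fetch = {"FunctionWithDtResourceInput"};

  // Run static shape inference over the graph, including the function node.
  tensorflow::grappler::GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  const auto& out =
      properties.GetOutputProperties("FunctionWithDtResourceInput");
  for (const auto& prop : out) {
    // Expect [1,3] for both z1 and z2.
    std::cout << prop.shape().DebugString() << std::endl;
  }
  return 0;
}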