Ensure that shapes are set for captured inputs to functions. Small eager tensors are captured as const ops, which don't have a shape attr.
PiperOrigin-RevId: 332927858 Change-Id: Ic1ee86e2e5c7db6332d1b989580bdb599c382604
This commit is contained in:
parent
65139ecc00
commit
906fac3188
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_node_util.h"
|
||||
@ -442,6 +443,14 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
|
||||
AttrValue value;
|
||||
*(value.mutable_list()->add_shape()) = attr.second.shape();
|
||||
arg_attrs.mutable_attr()->insert({"_output_shapes", value});
|
||||
} else if (attr.first == "value" && node->type_string() == "Const") {
|
||||
// Small eager tensors are captured as const ops rather than
|
||||
// Placeholders. Add a _output_shapes arg_attr with the shape of the
|
||||
// const tensor.
|
||||
AttrValue value;
|
||||
*(value.mutable_list()->add_shape()) =
|
||||
attr.second.tensor().tensor_shape();
|
||||
arg_attrs.mutable_attr()->insert({"_output_shapes", value});
|
||||
}
|
||||
if (attr.first == "_resource_arg_unique_id") {
|
||||
resource_arg_unique_id = attr.second.i();
|
||||
|
@ -130,7 +130,8 @@ NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
|
||||
|
||||
// We need to insert our argument before the placeholders, which are the last
|
||||
// arguments.
|
||||
OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
|
||||
OpDef_ArgDef* InsertSeedArgument(FunctionDef* function, int num_placeholders) {
|
||||
OpDef* signature = function->mutable_signature();
|
||||
int new_argument_idx = signature->input_arg_size() - num_placeholders;
|
||||
signature->add_input_arg();
|
||||
for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
|
||||
@ -140,6 +141,16 @@ OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
|
||||
seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx));
|
||||
seed_arg->set_type(DT_INT64);
|
||||
|
||||
// Update arg_attr, any arg_attrs for the placeholders how have index one
|
||||
// higher.
|
||||
for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
|
||||
if (function->arg_attr().contains(i - 1)) {
|
||||
(*function->mutable_arg_attr())[i] =
|
||||
(*function->mutable_arg_attr())[i - 1];
|
||||
function->mutable_arg_attr()->erase(i - 1);
|
||||
}
|
||||
}
|
||||
|
||||
return seed_arg;
|
||||
}
|
||||
|
||||
@ -157,8 +168,7 @@ const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
|
||||
graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
|
||||
stateless_function);
|
||||
|
||||
auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(),
|
||||
num_placeholders);
|
||||
auto* seed_arg = InsertSeedArgument(stateless_function, num_placeholders);
|
||||
|
||||
auto* const random_uniform = stateless_function->mutable_node_def(
|
||||
function_utils::FindFunctionNodeWithOp("RandomUniform",
|
||||
|
Loading…
Reference in New Issue
Block a user