Fix shape inference call in compile op
PiperOrigin-RevId: 337246090 Change-Id: I7053df6c4d7bb61244ad0b0081904156bffa975a
This commit is contained in:
parent
5452c25097
commit
19200d4ec5
@ -514,8 +514,7 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
|
|||||||
// Converts the GraphShapeInfo into the form needed by the constant-folding
|
// Converts the GraphShapeInfo into the form needed by the constant-folding
|
||||||
// pass of the optimizer.
|
// pass of the optimizer.
|
||||||
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
|
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
|
||||||
TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
|
ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
|
||||||
metadata, arg_shapes, graph->get(), flr, &shape_info));
|
|
||||||
optimizer_opts.shape_map = &shape_map;
|
optimizer_opts.shape_map = &shape_map;
|
||||||
optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
|
optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user