Fix shape inference call in compile op
PiperOrigin-RevId: 337246090 Change-Id: I7053df6c4d7bb61244ad0b0081904156bffa975a
This commit is contained in:
parent
5452c25097
commit
19200d4ec5
@ -514,8 +514,7 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
|
||||
// Converts the GraphShapeInfo into the form needed by the constant-folding
|
||||
// pass of the optimizer.
|
||||
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
|
||||
TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
|
||||
metadata, arg_shapes, graph->get(), flr, &shape_info));
|
||||
ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
|
||||
optimizer_opts.shape_map = &shape_map;
|
||||
optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user