Fix shape inference call in compile op

PiperOrigin-RevId: 337246090
Change-Id: I7053df6c4d7bb61244ad0b0081904156bffa975a
This commit is contained in:
Yuanzhong Xu 2020-10-14 23:01:29 -07:00 committed by TensorFlower Gardener
parent 5452c25097
commit 19200d4ec5

View File

@ -514,8 +514,7 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
// Converts the GraphShapeInfo into the form needed by the constant-folding
// pass of the optimizer.
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
metadata, arg_shapes, graph->get(), flr, &shape_info));
ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
optimizer_opts.shape_map = &shape_map;
optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
}