diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index ffc2b0248ee..76b9ab8c876 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -686,7 +686,11 @@ StatusOr> MakeFakeArguments(HloModule* const module, std::vector arguments(params.size()); for (int i = 0; i < params.size(); ++i) { const HloModuleConfig& module_config = module->config(); - const Shape& param_shape = module_config.has_entry_computation_layout() + const Shape& param_shape = (module_config.has_entry_computation_layout() && + module_config.entry_computation_layout() + .parameter_layout(i) + .shape() + .is_static()) ? module_config.entry_computation_layout() .parameter_layout(i) .shape()