Bug fix: create arguments based on module config's entry computation parameter shape only if it's static.

PiperOrigin-RevId: 349593500
Change-Id: Ice8b57c7c85a5b3457f3bd25bc3926a48d41c32e
This commit is contained in:
Jinliang Wei 2020-12-30 13:11:29 -08:00 committed by TensorFlower Gardener
parent ae6ff84595
commit 64c87d65ff

View File

@ -686,7 +686,11 @@ StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
std::vector<Literal> arguments(params.size()); std::vector<Literal> arguments(params.size());
for (int i = 0; i < params.size(); ++i) { for (int i = 0; i < params.size(); ++i) {
const HloModuleConfig& module_config = module->config(); const HloModuleConfig& module_config = module->config();
const Shape& param_shape = module_config.has_entry_computation_layout() const Shape& param_shape = (module_config.has_entry_computation_layout() &&
module_config.entry_computation_layout()
.parameter_layout(i)
.shape()
.is_static())
? module_config.entry_computation_layout() ? module_config.entry_computation_layout()
.parameter_layout(i) .parameter_layout(i)
.shape() .shape()