diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 4d55d939b72..8434e0394f4 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -21,6 +21,8 @@ from __future__ import print_function import collections import copy +import numpy as np + from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib @@ -34,6 +36,7 @@ from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -185,7 +188,9 @@ def _tpu_run(strategy, fn, args, kwargs): maximum_shapes = [] flattened_list = nest.flatten(replicate_inputs[0]) for input_tensor in flattened_list: - maximum_shapes.append(input_tensor.get_shape()) + maximum_shape = input_tensor.get_shape() if tensor_util.is_tensor( + input_tensor) else tensor_shape.TensorShape(np.shape(input_tensor)) + maximum_shapes.append(maximum_shape) maximum_shapes = nest.pack_sequence_as(replicate_inputs[0], maximum_shapes) else: