Fix TPUStrategy with numpy input in TF2.0.

PiperOrigin-RevId: 248773260
This commit is contained in:
Ruoxin Sang 2019-05-17 13:03:46 -07:00 committed by TensorFlower Gardener
parent 5a97851506
commit 39e91628f5

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import collections
import copy
import numpy as np
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
@ -34,6 +36,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -185,7 +188,9 @@ def _tpu_run(strategy, fn, args, kwargs):
maximum_shapes = []
flattened_list = nest.flatten(replicate_inputs[0])
for input_tensor in flattened_list:
maximum_shapes.append(input_tensor.get_shape())
maximum_shape = input_tensor.get_shape() if tensor_util.is_tensor(
input_tensor) else tensor_shape.TensorShape(np.shape(input_tensor))
maximum_shapes.append(maximum_shape)
maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
maximum_shapes)
else: