Fix TPUStrategy with numpy input in TF2.0.
PiperOrigin-RevId: 248773260
This commit is contained in:
parent
5a97851506
commit
39e91628f5
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||
import collections
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
@ -34,6 +36,7 @@ from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -185,7 +188,9 @@ def _tpu_run(strategy, fn, args, kwargs):
|
||||
maximum_shapes = []
|
||||
flattened_list = nest.flatten(replicate_inputs[0])
|
||||
for input_tensor in flattened_list:
|
||||
maximum_shapes.append(input_tensor.get_shape())
|
||||
maximum_shape = input_tensor.get_shape() if tensor_util.is_tensor(
|
||||
input_tensor) else tensor_shape.TensorShape(np.shape(input_tensor))
|
||||
maximum_shapes.append(maximum_shape)
|
||||
maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
|
||||
maximum_shapes)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user