Fix TPUStrategy with numpy input in TF2.0.
PiperOrigin-RevId: 248773260
This commit is contained in:
parent
5a97851506
commit
39e91628f5
@ -21,6 +21,8 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
from tensorflow.python.distribute import distribute_lib
|
from tensorflow.python.distribute import distribute_lib
|
||||||
@ -34,6 +36,7 @@ from tensorflow.python.eager import tape
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -185,7 +188,9 @@ def _tpu_run(strategy, fn, args, kwargs):
|
|||||||
maximum_shapes = []
|
maximum_shapes = []
|
||||||
flattened_list = nest.flatten(replicate_inputs[0])
|
flattened_list = nest.flatten(replicate_inputs[0])
|
||||||
for input_tensor in flattened_list:
|
for input_tensor in flattened_list:
|
||||||
maximum_shapes.append(input_tensor.get_shape())
|
maximum_shape = input_tensor.get_shape() if tensor_util.is_tensor(
|
||||||
|
input_tensor) else tensor_shape.TensorShape(np.shape(input_tensor))
|
||||||
|
maximum_shapes.append(maximum_shape)
|
||||||
maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
|
maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
|
||||||
maximum_shapes)
|
maximum_shapes)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user