Removes "fast paths" which are not fast in eager mode.

PiperOrigin-RevId: 168046278
This commit is contained in:
Alexandre Passos 2017-09-08 14:26:09 -07:00 committed by TensorFlower Gardener
parent 86f1713e51
commit f331f528b8

View File

@ -1428,6 +1428,8 @@ def zeros(shape, dtype=dtypes.float32, name=None):
zero = ""
else:
zero = 0
if context.in_eager_mode():
return fill(shape, constant(zero, dtype=dtype), name=name)
try:
shape = tensor_shape.as_shape(shape)
output = constant(zero, shape=shape, dtype=dtype, name=name)
@ -1466,6 +1468,13 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
with ops.name_scope(name, "zeros_like", [tensor]) as name:
tensor = ops.convert_to_tensor(tensor, name="tensor")
if context.in_eager_mode():
if dtype is not None and dtype != tensor.dtype:
return zeros(
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
with ops.device(tensor.device):
return gen_array_ops._zeros_like(tensor, name=name)
# For now, variant types must be created via zeros_like; as we need to
# pass the input variant object to the proper zeros callback.