Removes "fast paths" which are not fast in eager mode.
PiperOrigin-RevId: 168046278
This commit is contained in:
parent
86f1713e51
commit
f331f528b8
@ -1428,6 +1428,8 @@ def zeros(shape, dtype=dtypes.float32, name=None):
|
||||
zero = ""
|
||||
else:
|
||||
zero = 0
|
||||
if context.in_eager_mode():
|
||||
return fill(shape, constant(zero, dtype=dtype), name=name)
|
||||
try:
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
output = constant(zero, shape=shape, dtype=dtype, name=name)
|
||||
@ -1466,6 +1468,13 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
|
||||
with ops.name_scope(name, "zeros_like", [tensor]) as name:
|
||||
tensor = ops.convert_to_tensor(tensor, name="tensor")
|
||||
|
||||
if context.in_eager_mode():
|
||||
if dtype is not None and dtype != tensor.dtype:
|
||||
return zeros(
|
||||
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
|
||||
with ops.device(tensor.device):
|
||||
return gen_array_ops._zeros_like(tensor, name=name)
|
||||
|
||||
# For now, variant types must be created via zeros_like; as we need to
|
||||
# pass the input variant object to the proper zeros callback.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user