Removes "fast paths" which are not fast in eager mode.
PiperOrigin-RevId: 168046278
This commit is contained in:
parent
86f1713e51
commit
f331f528b8
@ -1428,6 +1428,8 @@ def zeros(shape, dtype=dtypes.float32, name=None):
|
|||||||
zero = ""
|
zero = ""
|
||||||
else:
|
else:
|
||||||
zero = 0
|
zero = 0
|
||||||
|
if context.in_eager_mode():
|
||||||
|
return fill(shape, constant(zero, dtype=dtype), name=name)
|
||||||
try:
|
try:
|
||||||
shape = tensor_shape.as_shape(shape)
|
shape = tensor_shape.as_shape(shape)
|
||||||
output = constant(zero, shape=shape, dtype=dtype, name=name)
|
output = constant(zero, shape=shape, dtype=dtype, name=name)
|
||||||
@ -1466,6 +1468,13 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
|
|||||||
with ops.name_scope(name, "zeros_like", [tensor]) as name:
|
with ops.name_scope(name, "zeros_like", [tensor]) as name:
|
||||||
tensor = ops.convert_to_tensor(tensor, name="tensor")
|
tensor = ops.convert_to_tensor(tensor, name="tensor")
|
||||||
|
|
||||||
|
if context.in_eager_mode():
|
||||||
|
if dtype is not None and dtype != tensor.dtype:
|
||||||
|
return zeros(
|
||||||
|
shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
|
||||||
|
with ops.device(tensor.device):
|
||||||
|
return gen_array_ops._zeros_like(tensor, name=name)
|
||||||
|
|
||||||
# For now, variant types must be created via zeros_like; as we need to
|
# For now, variant types must be created via zeros_like; as we need to
|
||||||
# pass the input variant object to the proper zeros callback.
|
# pass the input variant object to the proper zeros callback.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user