Internal cleanup
PiperOrigin-RevId: 171053770
This commit is contained in:
parent
cc8ee6c0f5
commit
e7c53698e0
@ -168,27 +168,31 @@ def make_tensor(v, arg_name):
|
||||
|
||||
def args_to_matching_eager(l, ctx, default_dtype=None):
|
||||
"""Convert sequence `l` to eager same-type Tensors."""
|
||||
EagerTensor = ops.EagerTensor # pylint: disable=invalid-name
|
||||
if all(isinstance(x, EagerTensor) for x in l):
|
||||
return l[0].dtype, l
|
||||
# TODO(josh11b): Could we do a better job if we also passed in the
|
||||
# allowed dtypes when that was known?
|
||||
|
||||
# Is some input already a Tensor with a dtype?
|
||||
dtype = None
|
||||
for t in l:
|
||||
if isinstance(t, ops.EagerTensor):
|
||||
if isinstance(t, EagerTensor):
|
||||
dtype = t.dtype
|
||||
break
|
||||
|
||||
internal_convert_to_tensor = ops.internal_convert_to_tensor
|
||||
if dtype is None:
|
||||
# Infer a dtype based on the first value, and use that dtype for the
|
||||
# remaining values.
|
||||
ret = []
|
||||
for t in l:
|
||||
ret.append(ops.internal_convert_to_tensor(
|
||||
ret.append(internal_convert_to_tensor(
|
||||
t, dtype, preferred_dtype=default_dtype, ctx=ctx))
|
||||
if dtype is None:
|
||||
dtype = ret[-1].dtype
|
||||
else:
|
||||
ret = [ops.internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]
|
||||
ret = [internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]
|
||||
|
||||
return dtype, ret
|
||||
|
||||
|
@ -112,8 +112,10 @@ class Layer(object):
|
||||
self._per_input_losses = {}
|
||||
self._per_input_updates = {}
|
||||
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
|
||||
self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
|
||||
or hasattr(self, 'compute_mask'))
|
||||
call_fn_args = estimator_util.fn_args(self.call)
|
||||
self._compute_previous_mask = ('mask' in call_fn_args or
|
||||
hasattr(self, 'compute_mask'))
|
||||
self._call_has_scope_arg = 'scope' in call_fn_args
|
||||
|
||||
# These lists will be filled via successive calls
|
||||
# to self._add_inbound_node().
|
||||
@ -555,7 +557,15 @@ class Layer(object):
|
||||
self.build(input_shapes[0])
|
||||
else:
|
||||
self.build(input_shapes)
|
||||
if 'scope' in estimator_util.fn_args(self.call):
|
||||
try:
|
||||
# Note: not all sub-classes of Layer call Layer.__init__ (especially
|
||||
# the ones under tensorflow/python/keras). Hence we recompute this
|
||||
# attribute here if it is not set.
|
||||
# TODO(agarwal): Fix the sub-classes and avoid this complexity.
|
||||
call_has_scope_arg = self._call_has_scope_arg
|
||||
except AttributeError:
|
||||
call_has_scope_arg = 'scope' in estimator_util.fn_args(self.call)
|
||||
if call_has_scope_arg:
|
||||
kwargs['scope'] = scope
|
||||
# Check input assumptions set after layer building, e.g. input shape.
|
||||
if in_graph_mode:
|
||||
@ -1433,8 +1443,10 @@ class Network(Layer):
|
||||
self._activity_regularizer = None
|
||||
self._scope = next(vs.variable_scope(None, default_name=base_name).gen)
|
||||
self._base_name = base_name
|
||||
self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
|
||||
or hasattr(self, 'compute_mask'))
|
||||
call_fn_args = estimator_util.fn_args(self.call)
|
||||
self._compute_previous_mask = ('mask' in call_fn_args or
|
||||
hasattr(self, 'compute_mask'))
|
||||
self._call_has_scope_arg = 'scope' in call_fn_args
|
||||
|
||||
# This acts just like the `trainable` attribute of any layer instance.
|
||||
# It does not affect users of the underlying layers, only users of the
|
||||
|
@ -330,7 +330,7 @@ class BatchNormalization(base.Layer):
|
||||
lambda: self._one_minus_decay,
|
||||
lambda: 0.)
|
||||
else:
|
||||
one_minus_decay = self._one_minus_decay
|
||||
one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
|
||||
if training_value or training_value is None:
|
||||
mean_update = self._assign_moving_average(self.moving_mean, mean,
|
||||
one_minus_decay)
|
||||
|
@ -540,17 +540,9 @@ class ResourceVariable(variables.Variable):
|
||||
the read operation.
|
||||
"""
|
||||
with ops.name_scope("Read"):
|
||||
# In graph mode, ensure we read the variable in the same device as the
|
||||
# handle. In eager mode, however, this sometimes tries to read a GPU
|
||||
# variable in the CPU because the handle is host memory. For now, then, we
|
||||
# need to skip the device block in eager. TODO(apassos): eager should have
|
||||
# separate notions of device and memory, so handle.device can be GPU while
|
||||
# handle.memory_space is always CPU.
|
||||
if context.in_graph_mode():
|
||||
# Ensure we read the variable in the same device as the handle.
|
||||
with ops.device(self._handle_device):
|
||||
value = self._read_variable_op()
|
||||
else:
|
||||
value = self._read_variable_op()
|
||||
# Return an identity so it can get placed on whatever device the context
|
||||
# specifies instead of the device where the variable is.
|
||||
return array_ops.identity(value)
|
||||
|
Loading…
Reference in New Issue
Block a user