Internal cleanup
PiperOrigin-RevId: 171053770
commit e7c53698e0
parent cc8ee6c0f5
@@ -168,27 +168,31 @@ def make_tensor(v, arg_name):
 
 def args_to_matching_eager(l, ctx, default_dtype=None):
   """Convert sequence `l` to eager same-type Tensors."""
+  EagerTensor = ops.EagerTensor  # pylint: disable=invalid-name
+  if all(isinstance(x, EagerTensor) for x in l):
+    return l[0].dtype, l
   # TODO(josh11b): Could we do a better job if we also passed in the
   # allowed dtypes when that was known?
 
   # Is some input already a Tensor with a dtype?
   dtype = None
   for t in l:
-    if isinstance(t, ops.EagerTensor):
+    if isinstance(t, EagerTensor):
       dtype = t.dtype
       break
 
+  internal_convert_to_tensor = ops.internal_convert_to_tensor
   if dtype is None:
     # Infer a dtype based on the first value, and use that dtype for the
     # remaining values.
     ret = []
     for t in l:
-      ret.append(ops.internal_convert_to_tensor(
+      ret.append(internal_convert_to_tensor(
           t, dtype, preferred_dtype=default_dtype, ctx=ctx))
       if dtype is None:
         dtype = ret[-1].dtype
   else:
-    ret = [ops.internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]
+    ret = [internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]
 
   return dtype, ret
 
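This hunk adds a fast path to args_to_matching_eager: when every element of `l` is already an EagerTensor, the function returns the first element's dtype and the list untouched, skipping the conversion loop entirely. Binding ops.EagerTensor and ops.internal_convert_to_tensor to locals also avoids repeated module-attribute lookups on a hot path. A minimal self-contained sketch of the same early-exit pattern, using a hypothetical Box type in place of EagerTensor:

class Box(object):
  """Hypothetical stand-in for an already-converted EagerTensor."""

  def __init__(self, value, dtype):
    self.value = value
    self.dtype = dtype


def to_box(value, dtype=None):
  """Hypothetical stand-in for ops.internal_convert_to_tensor."""
  return Box(value, dtype if dtype is not None else type(value).__name__)


def args_to_matching(l):
  # Fast path: if everything is already a Box, no conversion is needed.
  if all(isinstance(x, Box) for x in l):
    return l[0].dtype, l
  # Bind the converter once instead of looking it up per element.
  convert = to_box
  # Reuse the dtype of the first Box, if any. (The real function also feeds
  # the inferred dtype back into later conversions; omitted here for brevity.)
  dtype = next((t.dtype for t in l if isinstance(t, Box)), None)
  ret = [convert(t, dtype) for t in l]
  return dtype if dtype is not None else ret[0].dtype, ret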
@@ -112,8 +112,10 @@ class Layer(object):
     self._per_input_losses = {}
     self._per_input_updates = {}
     self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
-    self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
-                                   or hasattr(self, 'compute_mask'))
+    call_fn_args = estimator_util.fn_args(self.call)
+    self._compute_previous_mask = ('mask' in call_fn_args or
+                                   hasattr(self, 'compute_mask'))
+    self._call_has_scope_arg = 'scope' in call_fn_args
 
     # These lists will be filled via successive calls
     # to self._add_inbound_node().
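Here Layer.__init__ computes estimator_util.fn_args(self.call) once and derives both _compute_previous_mask and the new cached flag _call_has_scope_arg from the result, rather than re-inspecting the call signature later. A rough sketch of the idea, assuming a fn_args helper that returns a callable's argument names (the real helper lives in estimator_util):

import inspect


def fn_args(fn):
  # Simplified stand-in for estimator_util.fn_args: the names of fn's
  # parameters (for a bound method, `self` is already excluded).
  return tuple(inspect.signature(fn).parameters)


class LayerSketch(object):

  def __init__(self):
    # Inspect the call signature once and cache everything derived from it,
    # so __call__ does not pay for reflection on every invocation.
    call_fn_args = fn_args(self.call)
    self._compute_previous_mask = ('mask' in call_fn_args or
                                   hasattr(self, 'compute_mask'))
    self._call_has_scope_arg = 'scope' in call_fn_args

  def call(self, inputs, mask=None, scope=None):
    return inputs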
@@ -555,7 +557,15 @@ class Layer(object):
           self.build(input_shapes[0])
         else:
           self.build(input_shapes)
-      if 'scope' in estimator_util.fn_args(self.call):
+      try:
+        # Note: not all sub-classes of Layer call Layer.__init__ (especially
+        # the ones under tensorflow/python/keras). Hence we recompute this
+        # attribute here if it is not set.
+        # TODO(agarwal): Fix the sub-classes and avoid this complexity.
+        call_has_scope_arg = self._call_has_scope_arg
+      except AttributeError:
+        call_has_scope_arg = 'scope' in estimator_util.fn_args(self.call)
+      if call_has_scope_arg:
         kwargs['scope'] = scope
       # Check input assumptions set after layer building, e.g. input shape.
       if in_graph_mode:
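The __call__ side now prefers the cached attribute and falls back to signature inspection only when the attribute is absent, which happens when a subclass never ran Layer.__init__ (the comment calls out the Keras subclasses). A sketch of the fallback, reusing the fn_args helper from the previous sketch (the helper name here is hypothetical):

def call_has_scope_arg(layer):
  # Use the value cached by __init__ when present; recompute only when the
  # subclass skipped Layer.__init__ and never set the attribute.
  try:
    return layer._call_has_scope_arg
  except AttributeError:
    return 'scope' in fn_args(layer.call)

A getattr with a default expression would also work, but it would evaluate the fallback eagerly on every call, whereas try/except pays the inspection cost only when the attribute is genuinely missing.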
@@ -1433,8 +1443,10 @@ class Network(Layer):
     self._activity_regularizer = None
     self._scope = next(vs.variable_scope(None, default_name=base_name).gen)
     self._base_name = base_name
-    self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
-                                   or hasattr(self, 'compute_mask'))
+    call_fn_args = estimator_util.fn_args(self.call)
+    self._compute_previous_mask = ('mask' in call_fn_args or
+                                   hasattr(self, 'compute_mask'))
+    self._call_has_scope_arg = 'scope' in call_fn_args
 
     # This acts just like the `trainable` attribute of any layer instance.
     # It does not affect users of the underlying layers, only users of the
@@ -330,7 +330,7 @@ class BatchNormalization(base.Layer):
           lambda: self._one_minus_decay,
           lambda: 0.)
     else:
-      one_minus_decay = self._one_minus_decay
+      one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
     if training_value or training_value is None:
       mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                 one_minus_decay)
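In BatchNormalization, the else branch now wraps the Python float self._one_minus_decay in ops.convert_to_tensor before it is handed to both moving-average updates. A plausible reading (the commit message says only "Internal cleanup") is that converting once yields a single constant tensor the two _assign_moving_average calls can share, instead of each op converting the raw float implicitly. A small sketch under that assumption:

import tensorflow as tf

one_minus_decay = 0.01  # plain Python float, as self._one_minus_decay would be

# Convert once: the update expression below reuses this constant tensor
# instead of creating a fresh constant during implicit conversion.
decay_t = tf.convert_to_tensor(one_minus_decay)

var = tf.Variable([1.0, 2.0])
mean = tf.constant([0.5, 0.5])
# Moving-average step: variable -= (variable - value) * one_minus_decay.
update = var.assign_sub((var - mean) * decay_t)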
@@ -540,16 +540,8 @@ class ResourceVariable(variables.Variable):
       the read operation.
     """
     with ops.name_scope("Read"):
-      # In graph mode, ensure we read the variable in the same device as the
-      # handle. In eager mode, however, this sometimes tries to read a GPU
-      # variable in the CPU because the handle is host memory. For now, then, we
-      # need to skip the device block in eager. TODO(apassos): eager should have
-      # separate notions of device and memory, so handle.device can be GPU while
-      # handle.memory_space is always CPU.
-      if context.in_graph_mode():
-        with ops.device(self._handle_device):
-          value = self._read_variable_op()
-      else:
+      # Ensure we read the variable in the same device as the handle.
+      with ops.device(self._handle_device):
         value = self._read_variable_op()
     # Return an identity so it can get placed on whatever device the context
     # specifies instead of the device where the variable is.
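The ResourceVariable read loses its graph-versus-eager branch: the read now always runs under the handle's device, and the identity mentioned in the trailing comment lets the caller's device context choose the final placement. A simplified sketch of that read-then-identity placement pattern (the function and parameter names here are hypothetical):

import tensorflow as tf


def read_variable(handle_device, raw_read):
  # Perform the raw read on the device that owns the variable handle...
  with tf.name_scope("Read"):
    with tf.device(handle_device):
      value = raw_read()
  # ...then return an identity outside the device scope, so the result can
  # be placed on whatever device the surrounding context specifies.
  return tf.identity(value)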