Internal cleanup

PiperOrigin-RevId: 171053770
A. Unique TensorFlower 2017-10-04 13:26:47 -07:00 committed by TensorFlower Gardener
parent cc8ee6c0f5
commit e7c53698e0
4 changed files with 27 additions and 19 deletions

@@ -168,27 +168,31 @@ def make_tensor(v, arg_name):
 def args_to_matching_eager(l, ctx, default_dtype=None):
   """Convert sequence `l` to eager same-type Tensors."""
+  EagerTensor = ops.EagerTensor  # pylint: disable=invalid-name
+  if all(isinstance(x, EagerTensor) for x in l):
+    return l[0].dtype, l
   # TODO(josh11b): Could we do a better job if we also passed in the
   # allowed dtypes when that was known?
   # Is some input already a Tensor with a dtype?
   dtype = None
   for t in l:
-    if isinstance(t, ops.EagerTensor):
+    if isinstance(t, EagerTensor):
       dtype = t.dtype
       break
+  internal_convert_to_tensor = ops.internal_convert_to_tensor
   if dtype is None:
     # Infer a dtype based on the first value, and use that dtype for the
     # remaining values.
     ret = []
     for t in l:
-      ret.append(ops.internal_convert_to_tensor(
+      ret.append(internal_convert_to_tensor(
           t, dtype, preferred_dtype=default_dtype, ctx=ctx))
       if dtype is None:
         dtype = ret[-1].dtype
   else:
-    ret = [ops.internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]
+    ret = [internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]
   return dtype, ret
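
A minimal standalone sketch (illustrative names only, not TensorFlow code) of the two patterns visible in this hunk: an early-exit fast path when every element already has the target type, and hoisting a repeated module-attribute lookup into a local before the conversion loop.

def to_floats(values):
  # Fast path: every element already has the target type, return it unchanged.
  if all(isinstance(v, float) for v in values):
    return values
  convert = float  # bind the conversion function once, outside the loop
  return [convert(v) for v in values]

print(to_floats([1, 2.5, "3"]))  # [1.0, 2.5, 3.0]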

@@ -112,8 +112,10 @@ class Layer(object):
     self._per_input_losses = {}
     self._per_input_updates = {}
     self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
-    self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
-                                   or hasattr(self, 'compute_mask'))
+    call_fn_args = estimator_util.fn_args(self.call)
+    self._compute_previous_mask = ('mask' in call_fn_args or
+                                   hasattr(self, 'compute_mask'))
+    self._call_has_scope_arg = 'scope' in call_fn_args
     # These lists will be filled via successive calls
     # to self._add_inbound_node().
@@ -555,7 +557,15 @@ class Layer(object):
             self.build(input_shapes[0])
           else:
             self.build(input_shapes)
-        if 'scope' in estimator_util.fn_args(self.call):
+        try:
+          # Note: not all sub-classes of Layer call Layer.__init__ (especially
+          # the ones under tensorflow/python/keras). Hence we recompute this
+          # attribute here if it is not set.
+          # TODO(agarwal): Fix the sub-classes and avoid this complexity.
+          call_has_scope_arg = self._call_has_scope_arg
+        except AttributeError:
+          call_has_scope_arg = 'scope' in estimator_util.fn_args(self.call)
+        if call_has_scope_arg:
           kwargs['scope'] = scope
         # Check input assumptions set after layer building, e.g. input shape.
         if in_graph_mode:
@@ -1433,8 +1443,10 @@ class Network(Layer):
     self._activity_regularizer = None
     self._scope = next(vs.variable_scope(None, default_name=base_name).gen)
     self._base_name = base_name
-    self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
-                                   or hasattr(self, 'compute_mask'))
+    call_fn_args = estimator_util.fn_args(self.call)
+    self._compute_previous_mask = ('mask' in call_fn_args or
+                                   hasattr(self, 'compute_mask'))
+    self._call_has_scope_arg = 'scope' in call_fn_args
     # This acts just like the `trainable` attribute of any layer instance.
     # It does not affect users of the underlying layers, only users of the
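
These three Layer/Network hunks share one idea: run `estimator_util.fn_args(self.call)` once in `__init__`, cache what is needed (`_compute_previous_mask`, `_call_has_scope_arg`), and fall back to recomputing at call time only for subclasses that skip the base `__init__`. A rough standalone sketch of that caching-with-fallback pattern, using `inspect` and hypothetical class names rather than the TensorFlow ones:

import inspect

class Base(object):

  def __init__(self):
    # Introspect the subclass's `call` signature once and cache the result.
    call_args = inspect.signature(self.call).parameters
    self._call_has_scope_arg = 'scope' in call_args

  def call(self, inputs):
    return inputs

  def apply(self, inputs, scope=None):
    try:
      has_scope = self._call_has_scope_arg
    except AttributeError:
      # Some subclasses never call Base.__init__, so recompute lazily here.
      has_scope = 'scope' in inspect.signature(self.call).parameters
    if has_scope:
      return self.call(inputs, scope=scope)
    return self.call(inputs)

class Scoped(Base):

  def call(self, inputs, scope=None):
    return inputs

print(Scoped().apply([1, 2], scope='outer'))  # [1, 2]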

@@ -330,7 +330,7 @@ class BatchNormalization(base.Layer):
                                          lambda: self._one_minus_decay,
                                          lambda: 0.)
       else:
-        one_minus_decay = self._one_minus_decay
+        one_minus_decay = ops.convert_to_tensor(self._one_minus_decay)
       if training_value or training_value is None:
         mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                   one_minus_decay)
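
The one-line change above converts the Python float `self._one_minus_decay` to a `Tensor` before it reaches `_assign_moving_average`, so that code path receives a `Tensor` rather than a raw Python scalar. A tiny sketch of the conversion using the public TensorFlow API (the variable name is illustrative):

import tensorflow as tf

one_minus_decay = 0.001                                     # plain Python float
one_minus_decay_t = tf.convert_to_tensor(one_minus_decay)   # same value, as a Tensor
print(one_minus_decay_t.dtype)                              # <dtype: 'float32'>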

@@ -540,17 +540,9 @@ class ResourceVariable(variables.Variable):
      the read operation.
    """
    with ops.name_scope("Read"):
-      # In graph mode, ensure we read the variable in the same device as the
-      # handle. In eager mode, however, this sometimes tries to read a GPU
-      # variable in the CPU because the handle is host memory. For now, then, we
-      # need to skip the device block in eager. TODO(apassos): eager should have
-      # separate notions of device and memory, so handle.device can be GPU while
-      # handle.memory_space is always CPU.
-      if context.in_graph_mode():
-        with ops.device(self._handle_device):
-          value = self._read_variable_op()
-      else:
-        value = self._read_variable_op()
+      # Ensure we read the variable in the same device as the handle.
+      with ops.device(self._handle_device):
+        value = self._read_variable_op()
    # Return an identity so it can get placed on whatever device the context
    # specifies instead of the device where the variable is.
    return array_ops.identity(value)
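
The hunk above drops the graph-vs-eager special case, so the read is always pinned to the handle's device, while the trailing `identity` still lets the surrounding context place the returned value. A rough sketch of that placement pattern with the public API (not the ResourceVariable internals; the device string is illustrative):

import tensorflow as tf

v = tf.Variable([1.0, 2.0])
with tf.device('/cpu:0'):    # read on the device that owns the variable handle
  raw = v.read_value()
value = tf.identity(raw)     # placed by the caller's device context, not the variable's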