Makes some ref-related dtype methods private.

Change: 139484060
This commit is contained in:
A. Unique TensorFlower 2016-11-17 11:21:09 -08:00 committed by TensorFlower Gardener
parent 7d0573f0d7
commit d535017fc2
6 changed files with 16 additions and 18 deletions

View File

@ -382,7 +382,7 @@ class NodeStepper(object):
# Determine whether the input is feedable. Reference-type tensors,
# e.g., Variables, should not be fed, because they can change.
if isinstance(inp, ops.Tensor):
is_inp_ref = inp.dtype.is_ref_dtype
is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access
can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref
else:
is_inp_ref = False

View File

@ -59,8 +59,6 @@ class DType(object):
@@name
@@base_dtype
@@real_dtype
@@is_ref_dtype
@@as_ref
@@is_floating
@@is_complex
@@is_integer
@ -97,14 +95,14 @@ class DType(object):
self._type_enum = type_enum
@property
def is_ref_dtype(self):
def _is_ref_dtype(self):
"""Returns `True` if this `DType` represents a reference type."""
return self._type_enum > 100
@property
def as_ref(self):
def _as_ref(self):
"""Returns a reference `DType` based on this `DType`."""
if self.is_ref_dtype:
if self._is_ref_dtype:
return self
else:
return _INTERN_TABLE[self._type_enum + 100]
@ -112,7 +110,7 @@ class DType(object):
@property
def base_dtype(self):
"""Returns a non-reference `DType` based on this `DType`."""
if self.is_ref_dtype:
if self._is_ref_dtype:
return _INTERN_TABLE[self._type_enum - 100]
else:
return self
@ -269,7 +267,7 @@ class DType(object):
return False
try:
dtype = as_dtype(other).as_datatype_enum
return self._type_enum == dtype
return self._type_enum == dtype # pylint: disable=protected-access
except TypeError:
return False

View File

@ -60,7 +60,7 @@ def _ArgToTypesNoRef(node_def, arg_def):
def _SingleArgToTypes(node_def, arg_def):
types = _ArgToTypesNoRef(node_def, arg_def)
if arg_def.is_ref:
return [dtypes.as_dtype(dt).as_ref.as_datatype_enum for dt in types]
return [dtypes.as_dtype(dt)._as_ref.as_datatype_enum for dt in types] # pylint: disable=protected-access
return types

View File

@ -718,7 +718,7 @@ class ImportGraphDefTest(tf.test.TestCase):
# We'll use the following device function to observe ops with two inputs.
ops_with_two_inputs = []
def input_counter(op):
if any(in_t.dtype.is_ref_dtype for in_t in op.inputs):
if any(in_t.dtype._is_ref_dtype for in_t in op.inputs): # pylint: disable=protected-access
ops_with_two_inputs.append(op)
return ""

View File

@ -607,7 +607,7 @@ class OpDefLibrary(object):
assert False, "Unreachable"
if input_arg.is_ref:
if not all(x.is_ref_dtype for x in types):
if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access
raise TypeError(
"Input '%s' of '%s' Op requires l-value input" %
(input_name, op_type_name))
@ -741,7 +741,7 @@ class OpDefLibrary(object):
types = [arg.type]
output_structure.append(None)
if arg.is_ref:
types = [dtypes.as_dtype(x).as_ref for x in types]
types = [dtypes.as_dtype(x)._as_ref for x in types] # pylint: disable=protected-access
output_types.extend(types)
if keywords:

View File

@ -160,7 +160,7 @@ def _Identity(data, name=None):
"""
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype.is_ref_dtype:
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return gen_array_ops._ref_identity(data, name=name)
else:
return array_ops.identity(data, name=name)
@ -182,7 +182,7 @@ def _Identity(data, name=None):
def _NextIteration(data, name=None):
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype.is_ref_dtype:
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return ref_next_iteration(data, name=name)
else:
return next_iteration(data, name=name)
@ -223,7 +223,7 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
"""
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype.is_ref_dtype and use_ref:
if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access
result = ref_enter(data, frame_name, is_constant, parallel_iterations,
name=name)
else:
@ -272,7 +272,7 @@ def exit(data, name=None):
"""
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype.is_ref_dtype:
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return gen_control_flow_ops._ref_exit(data, name)
else:
return gen_control_flow_ops._exit(data, name)
@ -378,7 +378,7 @@ def _SwitchRefOrTensor(data, pred, name="Switch"):
# created within ops.colocate_with(data) to ignore the existing stack.
with ops.colocate_with(data, ignore_existing=True):
if isinstance(data, ops.Tensor):
if data.dtype.is_ref_dtype:
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return ref_switch(data, pred, name=name)
return switch(data, pred, name=name)
@ -414,7 +414,7 @@ def merge(inputs, name=None):
inputs = [ops.convert_to_tensor_or_indexed_slices(inp, as_ref=True)
for inp in inputs]
if all([isinstance(v, ops.Tensor) for v in inputs]):
if all([v.dtype.is_ref_dtype for v in inputs]):
if all([v.dtype._is_ref_dtype for v in inputs]): # pylint: disable=protected-access
return gen_control_flow_ops._ref_merge(inputs, name)
else:
return gen_control_flow_ops._merge(inputs, name)