Makes some ref-related dtype methods private.
Change: 139484060
This commit is contained in:
parent
7d0573f0d7
commit
d535017fc2
@ -382,7 +382,7 @@ class NodeStepper(object):
|
|||||||
# Determine whether the input is feedable. Reference-type tensors,
|
# Determine whether the input is feedable. Reference-type tensors,
|
||||||
# e.g., Variables, should not be fed, because they can change.
|
# e.g., Variables, should not be fed, because they can change.
|
||||||
if isinstance(inp, ops.Tensor):
|
if isinstance(inp, ops.Tensor):
|
||||||
is_inp_ref = inp.dtype.is_ref_dtype
|
is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access
|
||||||
can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref
|
can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref
|
||||||
else:
|
else:
|
||||||
is_inp_ref = False
|
is_inp_ref = False
|
||||||
|
@ -59,8 +59,6 @@ class DType(object):
|
|||||||
@@name
|
@@name
|
||||||
@@base_dtype
|
@@base_dtype
|
||||||
@@real_dtype
|
@@real_dtype
|
||||||
@@is_ref_dtype
|
|
||||||
@@as_ref
|
|
||||||
@@is_floating
|
@@is_floating
|
||||||
@@is_complex
|
@@is_complex
|
||||||
@@is_integer
|
@@is_integer
|
||||||
@ -97,14 +95,14 @@ class DType(object):
|
|||||||
self._type_enum = type_enum
|
self._type_enum = type_enum
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_ref_dtype(self):
|
def _is_ref_dtype(self):
|
||||||
"""Returns `True` if this `DType` represents a reference type."""
|
"""Returns `True` if this `DType` represents a reference type."""
|
||||||
return self._type_enum > 100
|
return self._type_enum > 100
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def as_ref(self):
|
def _as_ref(self):
|
||||||
"""Returns a reference `DType` based on this `DType`."""
|
"""Returns a reference `DType` based on this `DType`."""
|
||||||
if self.is_ref_dtype:
|
if self._is_ref_dtype:
|
||||||
return self
|
return self
|
||||||
else:
|
else:
|
||||||
return _INTERN_TABLE[self._type_enum + 100]
|
return _INTERN_TABLE[self._type_enum + 100]
|
||||||
@ -112,7 +110,7 @@ class DType(object):
|
|||||||
@property
|
@property
|
||||||
def base_dtype(self):
|
def base_dtype(self):
|
||||||
"""Returns a non-reference `DType` based on this `DType`."""
|
"""Returns a non-reference `DType` based on this `DType`."""
|
||||||
if self.is_ref_dtype:
|
if self._is_ref_dtype:
|
||||||
return _INTERN_TABLE[self._type_enum - 100]
|
return _INTERN_TABLE[self._type_enum - 100]
|
||||||
else:
|
else:
|
||||||
return self
|
return self
|
||||||
@ -269,7 +267,7 @@ class DType(object):
|
|||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
dtype = as_dtype(other).as_datatype_enum
|
dtype = as_dtype(other).as_datatype_enum
|
||||||
return self._type_enum == dtype
|
return self._type_enum == dtype # pylint: disable=protected-access
|
||||||
except TypeError:
|
except TypeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def _ArgToTypesNoRef(node_def, arg_def):
|
|||||||
def _SingleArgToTypes(node_def, arg_def):
|
def _SingleArgToTypes(node_def, arg_def):
|
||||||
types = _ArgToTypesNoRef(node_def, arg_def)
|
types = _ArgToTypesNoRef(node_def, arg_def)
|
||||||
if arg_def.is_ref:
|
if arg_def.is_ref:
|
||||||
return [dtypes.as_dtype(dt).as_ref.as_datatype_enum for dt in types]
|
return [dtypes.as_dtype(dt)._as_ref.as_datatype_enum for dt in types] # pylint: disable=protected-access
|
||||||
return types
|
return types
|
||||||
|
|
||||||
|
|
||||||
|
@ -718,7 +718,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
|||||||
# We'll use the following device function to observe ops with two inputs.
|
# We'll use the following device function to observe ops with two inputs.
|
||||||
ops_with_two_inputs = []
|
ops_with_two_inputs = []
|
||||||
def input_counter(op):
|
def input_counter(op):
|
||||||
if any(in_t.dtype.is_ref_dtype for in_t in op.inputs):
|
if any(in_t.dtype._is_ref_dtype for in_t in op.inputs): # pylint: disable=protected-access
|
||||||
ops_with_two_inputs.append(op)
|
ops_with_two_inputs.append(op)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@ -607,7 +607,7 @@ class OpDefLibrary(object):
|
|||||||
assert False, "Unreachable"
|
assert False, "Unreachable"
|
||||||
|
|
||||||
if input_arg.is_ref:
|
if input_arg.is_ref:
|
||||||
if not all(x.is_ref_dtype for x in types):
|
if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Input '%s' of '%s' Op requires l-value input" %
|
"Input '%s' of '%s' Op requires l-value input" %
|
||||||
(input_name, op_type_name))
|
(input_name, op_type_name))
|
||||||
@ -741,7 +741,7 @@ class OpDefLibrary(object):
|
|||||||
types = [arg.type]
|
types = [arg.type]
|
||||||
output_structure.append(None)
|
output_structure.append(None)
|
||||||
if arg.is_ref:
|
if arg.is_ref:
|
||||||
types = [dtypes.as_dtype(x).as_ref for x in types]
|
types = [dtypes.as_dtype(x)._as_ref for x in types] # pylint: disable=protected-access
|
||||||
output_types.extend(types)
|
output_types.extend(types)
|
||||||
|
|
||||||
if keywords:
|
if keywords:
|
||||||
|
@ -160,7 +160,7 @@ def _Identity(data, name=None):
|
|||||||
"""
|
"""
|
||||||
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
|
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
|
||||||
if isinstance(data, ops.Tensor):
|
if isinstance(data, ops.Tensor):
|
||||||
if data.dtype.is_ref_dtype:
|
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
|
||||||
return gen_array_ops._ref_identity(data, name=name)
|
return gen_array_ops._ref_identity(data, name=name)
|
||||||
else:
|
else:
|
||||||
return array_ops.identity(data, name=name)
|
return array_ops.identity(data, name=name)
|
||||||
@ -182,7 +182,7 @@ def _Identity(data, name=None):
|
|||||||
def _NextIteration(data, name=None):
|
def _NextIteration(data, name=None):
|
||||||
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
|
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
|
||||||
if isinstance(data, ops.Tensor):
|
if isinstance(data, ops.Tensor):
|
||||||
if data.dtype.is_ref_dtype:
|
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
|
||||||
return ref_next_iteration(data, name=name)
|
return ref_next_iteration(data, name=name)
|
||||||
else:
|
else:
|
||||||
return next_iteration(data, name=name)
|
return next_iteration(data, name=name)
|
||||||
@ -223,7 +223,7 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
|
|||||||
"""
|
"""
|
||||||
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
|
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
|
||||||
if isinstance(data, ops.Tensor):
|
if isinstance(data, ops.Tensor):
|
||||||
if data.dtype.is_ref_dtype and use_ref:
|
if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access
|
||||||
result = ref_enter(data, frame_name, is_constant, parallel_iterations,
|
result = ref_enter(data, frame_name, is_constant, parallel_iterations,
|
||||||
name=name)
|
name=name)
|
||||||
else:
|
else:
|
||||||
@ -272,7 +272,7 @@ def exit(data, name=None):
|
|||||||
"""
|
"""
|
||||||
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
|
data = ops.convert_to_tensor_or_indexed_slices(data, as_ref=True)
|
||||||
if isinstance(data, ops.Tensor):
|
if isinstance(data, ops.Tensor):
|
||||||
if data.dtype.is_ref_dtype:
|
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
|
||||||
return gen_control_flow_ops._ref_exit(data, name)
|
return gen_control_flow_ops._ref_exit(data, name)
|
||||||
else:
|
else:
|
||||||
return gen_control_flow_ops._exit(data, name)
|
return gen_control_flow_ops._exit(data, name)
|
||||||
@ -378,7 +378,7 @@ def _SwitchRefOrTensor(data, pred, name="Switch"):
|
|||||||
# created within ops.colocate_with(data) to ignore the existing stack.
|
# created within ops.colocate_with(data) to ignore the existing stack.
|
||||||
with ops.colocate_with(data, ignore_existing=True):
|
with ops.colocate_with(data, ignore_existing=True):
|
||||||
if isinstance(data, ops.Tensor):
|
if isinstance(data, ops.Tensor):
|
||||||
if data.dtype.is_ref_dtype:
|
if data.dtype._is_ref_dtype: # pylint: disable=protected-access
|
||||||
return ref_switch(data, pred, name=name)
|
return ref_switch(data, pred, name=name)
|
||||||
return switch(data, pred, name=name)
|
return switch(data, pred, name=name)
|
||||||
|
|
||||||
@ -414,7 +414,7 @@ def merge(inputs, name=None):
|
|||||||
inputs = [ops.convert_to_tensor_or_indexed_slices(inp, as_ref=True)
|
inputs = [ops.convert_to_tensor_or_indexed_slices(inp, as_ref=True)
|
||||||
for inp in inputs]
|
for inp in inputs]
|
||||||
if all([isinstance(v, ops.Tensor) for v in inputs]):
|
if all([isinstance(v, ops.Tensor) for v in inputs]):
|
||||||
if all([v.dtype.is_ref_dtype for v in inputs]):
|
if all([v.dtype._is_ref_dtype for v in inputs]): # pylint: disable=protected-access
|
||||||
return gen_control_flow_ops._ref_merge(inputs, name)
|
return gen_control_flow_ops._ref_merge(inputs, name)
|
||||||
else:
|
else:
|
||||||
return gen_control_flow_ops._merge(inputs, name)
|
return gen_control_flow_ops._merge(inputs, name)
|
||||||
|
Loading…
Reference in New Issue
Block a user