Fix incorrect "is" comparisons to instead use "==".

PiperOrigin-RevId: 237140147
This commit is contained in:
A. Unique TensorFlower 2019-03-06 16:10:16 -08:00 committed by TensorFlower Gardener
parent f193648584
commit 18722841f5
10 changed files with 11 additions and 11 deletions

View File

@@ -112,7 +112,7 @@ class VirtualBatchnormTest(test.TestCase):
           batch, axis, training=True)
       # Get VBN's batch normalization on reference batch.
-      batch_axis = 0 if axis is not 0 else 1  # axis and batch_axis can't same
+      batch_axis = 0 if axis != 0 else 1  # axis and batch_axis can't same
       vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
       vbn_normalized = vbn.reference_batch_normalization()

View File

@@ -28,7 +28,7 @@ def new_data_to_array(fn):
   vals = []
   with open(fn) as f:
     for n, line in enumerate(f):
-      if n is not 0:
+      if n != 0:
         vals.extend([int(v, 16) for v in line.split()])
   b = ''.join(map(chr, vals))
   y = struct.unpack('<' + 'h' * int(len(b) / 2), b)

View File

@@ -30,7 +30,7 @@ def new_data_to_array(fn, datatype='int16'):
   vals = []
   with open(fn) as f:
     for n, line in enumerate(f):
-      if n is not 0:
+      if n != 0:
         vals.extend([int(v, 16) for v in line.split()])
   b = ''.join(map(chr, vals))

View File

@@ -157,7 +157,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
     for i in range(34):
       self.assertEqual(i * 3, self.evaluate(next_element()))
       summary_str = self.evaluate(aggregator.get_summary())
-      if i is not 0:
+      if i != 0:
         self._assertSummaryHasScalarValue(
             summary_str,
             self.regexForNodeName("FilterDataset", "dropped_elements"),

View File

@@ -319,7 +319,7 @@ def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
   devices = []
   for task_type in ("chief", "worker"):
     for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
-      if num_gpus_per_worker is 0:
+      if num_gpus_per_worker == 0:
        devices.append("/job:%s/task:%d" % (task_type, task_id))
       else:
         devices.extend([

View File

@@ -124,7 +124,7 @@ class CostAnalysisTest(test.TestCase):
       op_count = int(m.group(1))
       # upper = int(m.group(5))
       lower = int(m.group(6))
-      if op_type is b"MatMul":
+      if op_type == b"MatMul":
         self.assertEqual(3, op_count)
       else:
         self.assertEqual(1, op_count)

View File

@@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging
 class SoftmaxTest(test.TestCase):

   def _npSoftmax(self, features, dim=-1, log=False):
-    if dim is -1:
+    if dim == -1:
       dim = len(features.shape) - 1
     one_only_on_dim = list(features.shape)
     one_only_on_dim[dim] = 1

View File

@@ -41,7 +41,7 @@ from tensorflow.python.platform import test
 class XentTest(test.TestCase):

   def _npXent(self, features, labels, dim=-1):
-    if dim is -1:
+    if dim == -1:
       dim = len(features.shape) - 1
     one_only_on_dim = list(features.shape)
     one_only_on_dim[dim] = 1

View File

@@ -154,7 +154,7 @@ def _get_pad_shape(params, indices):
   if params.shape.ndims == indices.shape.ndims:
     # When params and indices are the same rank, the shape of the pad tensor is
     # almost identical to params, except the last dimension which has size = 1.
-    if params_shape.num_inner_dimensions is 0:
+    if params_shape.num_inner_dimensions == 0:
       pad_dims = params_shape.partitioned_dim_sizes[:-1] + (
           array_ops.ones_like(params_shape.partitioned_dim_sizes[-1]),)
       return ragged_tensor_shape.RaggedTensorDynamicShape(

View File

@@ -67,7 +67,7 @@ def get_plugin_asset(plugin_asset_cls, graph=None):
   name = _PLUGIN_ASSET_PREFIX + plugin_asset_cls.plugin_name
   container = graph.get_collection(name)
   if container:
-    if len(container) is not 1:
+    if len(container) != 1:
       raise ValueError("Collection for %s had %d items, expected 1" %
                        (name, len(container)))
     instance = container[0]
@@ -102,7 +102,7 @@ def get_all_plugin_assets(graph=None):
   out = []
   for name in graph.get_collection(_PLUGIN_ASSET_PREFIX):
     collection = graph.get_collection(_PLUGIN_ASSET_PREFIX + name)
-    if len(collection) is not 1:
+    if len(collection) != 1:
       raise ValueError("Collection for %s had %d items, expected 1" %
                        (name, len(collection)))
     out.append(collection[0])