Fix incorrect "is" comparisons to instead use "==".
PiperOrigin-RevId: 237140147
commit 18722841f5 (parent f193648584)
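The change itself is mechanical, but the reason matters: in Python, `is` / `is not` compare object identity, while `==` / `!=` compare values. Identity checks against int or bytes literals only appear to work because CPython happens to cache small integers, so they are implementation-dependent and can silently give the wrong answer. A minimal standalone sketch (not part of the diff; the values are purely illustrative):

# Value equality vs. identity: CPython only caches small ints (-5..256) as
# singletons, so `is` against a number literal works by accident at best.
a = int("1001")   # built at runtime, outside the small-int cache
b = int("1001")
print(a == b)     # True  -- compares values; the right operator here
print(a is b)     # False -- two distinct int objects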
@@ -112,7 +112,7 @@ class VirtualBatchnormTest(test.TestCase):
           batch, axis, training=True)
 
       # Get VBN's batch normalization on reference batch.
-      batch_axis = 0 if axis is not 0 else 1  # axis and batch_axis can't same
+      batch_axis = 0 if axis != 0 else 1  # axis and batch_axis can't same
       vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
       vbn_normalized = vbn.reference_batch_normalization()
 
@@ -28,7 +28,7 @@ def new_data_to_array(fn):
   vals = []
   with open(fn) as f:
     for n, line in enumerate(f):
-      if n is not 0:
+      if n != 0:
         vals.extend([int(v, 16) for v in line.split()])
   b = ''.join(map(chr, vals))
   y = struct.unpack('<' + 'h' * int(len(b) / 2), b)
@@ -30,7 +30,7 @@ def new_data_to_array(fn, datatype='int16'):
   vals = []
   with open(fn) as f:
     for n, line in enumerate(f):
-      if n is not 0:
+      if n != 0:
         vals.extend([int(v, 16) for v in line.split()])
   b = ''.join(map(chr, vals))
 
@@ -157,7 +157,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
     for i in range(34):
       self.assertEqual(i * 3, self.evaluate(next_element()))
       summary_str = self.evaluate(aggregator.get_summary())
-      if i is not 0:
+      if i != 0:
         self._assertSummaryHasScalarValue(
             summary_str,
             self.regexForNodeName("FilterDataset", "dropped_elements"),
@@ -319,7 +319,7 @@ def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
   devices = []
   for task_type in ("chief", "worker"):
     for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
-      if num_gpus_per_worker is 0:
+      if num_gpus_per_worker == 0:
         devices.append("/job:%s/task:%d" % (task_type, task_id))
       else:
         devices.extend([
@@ -124,7 +124,7 @@ class CostAnalysisTest(test.TestCase):
       op_count = int(m.group(1))
       # upper = int(m.group(5))
       lower = int(m.group(6))
-      if op_type is b"MatMul":
+      if op_type == b"MatMul":
         self.assertEqual(3, op_count)
       else:
         self.assertEqual(1, op_count)
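The `op_type is b"MatMul"` case above is the clearest of these bugs: a bytes (or str) value built at runtime is generally not the same object as a literal with the same contents, so the identity test can be False even when the values match and the `else` branch runs instead. A small sketch, independent of the test code above:

import re

# A bytes object extracted at runtime has the same contents as the literal
# but is a different object, so `is` fails where `==` succeeds.
m = re.search(rb"(\w+)", b"MatMul cost report")
op_type = m.group(1)          # b"MatMul", freshly allocated
print(op_type == b"MatMul")   # True  -- contents match
print(op_type is b"MatMul")   # False -- different objects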
@@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging
 class SoftmaxTest(test.TestCase):
 
   def _npSoftmax(self, features, dim=-1, log=False):
-    if dim is -1:
+    if dim == -1:
       dim = len(features.shape) - 1
     one_only_on_dim = list(features.shape)
     one_only_on_dim[dim] = 1
@@ -41,7 +41,7 @@ from tensorflow.python.platform import test
 class XentTest(test.TestCase):
 
   def _npXent(self, features, labels, dim=-1):
-    if dim is -1:
+    if dim == -1:
       dim = len(features.shape) - 1
     one_only_on_dim = list(features.shape)
     one_only_on_dim[dim] = 1
@@ -154,7 +154,7 @@ def _get_pad_shape(params, indices):
   if params.shape.ndims == indices.shape.ndims:
     # When params and indices are the same rank, the shape of the pad tensor is
     # almost identical to params, except the last dimension which has size = 1.
-    if params_shape.num_inner_dimensions is 0:
+    if params_shape.num_inner_dimensions == 0:
       pad_dims = params_shape.partitioned_dim_sizes[:-1] + (
           array_ops.ones_like(params_shape.partitioned_dim_sizes[-1]),)
       return ragged_tensor_shape.RaggedTensorDynamicShape(
@@ -67,7 +67,7 @@ def get_plugin_asset(plugin_asset_cls, graph=None):
   name = _PLUGIN_ASSET_PREFIX + plugin_asset_cls.plugin_name
   container = graph.get_collection(name)
   if container:
-    if len(container) is not 1:
+    if len(container) != 1:
       raise ValueError("Collection for %s had %d items, expected 1" %
                        (name, len(container)))
     instance = container[0]
@@ -102,7 +102,7 @@ def get_all_plugin_assets(graph=None):
   out = []
   for name in graph.get_collection(_PLUGIN_ASSET_PREFIX):
     collection = graph.get_collection(_PLUGIN_ASSET_PREFIX + name)
-    if len(collection) is not 1:
+    if len(collection) != 1:
       raise ValueError("Collection for %s had %d items, expected 1" %
                        (name, len(collection)))
     out.append(collection[0])