Fix incorrect "is" comparisons to instead use "==".
PiperOrigin-RevId: 237140147
This commit is contained in:
parent
f193648584
commit
18722841f5
@ -112,7 +112,7 @@ class VirtualBatchnormTest(test.TestCase):
|
|||||||
batch, axis, training=True)
|
batch, axis, training=True)
|
||||||
|
|
||||||
# Get VBN's batch normalization on reference batch.
|
# Get VBN's batch normalization on reference batch.
|
||||||
batch_axis = 0 if axis is not 0 else 1 # axis and batch_axis can't same
|
batch_axis = 0 if axis != 0 else 1 # axis and batch_axis can't same
|
||||||
vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
|
vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
|
||||||
vbn_normalized = vbn.reference_batch_normalization()
|
vbn_normalized = vbn.reference_batch_normalization()
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ def new_data_to_array(fn):
|
|||||||
vals = []
|
vals = []
|
||||||
with open(fn) as f:
|
with open(fn) as f:
|
||||||
for n, line in enumerate(f):
|
for n, line in enumerate(f):
|
||||||
if n is not 0:
|
if n != 0:
|
||||||
vals.extend([int(v, 16) for v in line.split()])
|
vals.extend([int(v, 16) for v in line.split()])
|
||||||
b = ''.join(map(chr, vals))
|
b = ''.join(map(chr, vals))
|
||||||
y = struct.unpack('<' + 'h' * int(len(b) / 2), b)
|
y = struct.unpack('<' + 'h' * int(len(b) / 2), b)
|
||||||
|
@ -30,7 +30,7 @@ def new_data_to_array(fn, datatype='int16'):
|
|||||||
vals = []
|
vals = []
|
||||||
with open(fn) as f:
|
with open(fn) as f:
|
||||||
for n, line in enumerate(f):
|
for n, line in enumerate(f):
|
||||||
if n is not 0:
|
if n != 0:
|
||||||
vals.extend([int(v, 16) for v in line.split()])
|
vals.extend([int(v, 16) for v in line.split()])
|
||||||
b = ''.join(map(chr, vals))
|
b = ''.join(map(chr, vals))
|
||||||
|
|
||||||
|
@ -157,7 +157,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
for i in range(34):
|
for i in range(34):
|
||||||
self.assertEqual(i * 3, self.evaluate(next_element()))
|
self.assertEqual(i * 3, self.evaluate(next_element()))
|
||||||
summary_str = self.evaluate(aggregator.get_summary())
|
summary_str = self.evaluate(aggregator.get_summary())
|
||||||
if i is not 0:
|
if i != 0:
|
||||||
self._assertSummaryHasScalarValue(
|
self._assertSummaryHasScalarValue(
|
||||||
summary_str,
|
summary_str,
|
||||||
self.regexForNodeName("FilterDataset", "dropped_elements"),
|
self.regexForNodeName("FilterDataset", "dropped_elements"),
|
||||||
|
@ -319,7 +319,7 @@ def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
|
|||||||
devices = []
|
devices = []
|
||||||
for task_type in ("chief", "worker"):
|
for task_type in ("chief", "worker"):
|
||||||
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
|
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
|
||||||
if num_gpus_per_worker is 0:
|
if num_gpus_per_worker == 0:
|
||||||
devices.append("/job:%s/task:%d" % (task_type, task_id))
|
devices.append("/job:%s/task:%d" % (task_type, task_id))
|
||||||
else:
|
else:
|
||||||
devices.extend([
|
devices.extend([
|
||||||
|
@ -124,7 +124,7 @@ class CostAnalysisTest(test.TestCase):
|
|||||||
op_count = int(m.group(1))
|
op_count = int(m.group(1))
|
||||||
# upper = int(m.group(5))
|
# upper = int(m.group(5))
|
||||||
lower = int(m.group(6))
|
lower = int(m.group(6))
|
||||||
if op_type is b"MatMul":
|
if op_type == b"MatMul":
|
||||||
self.assertEqual(3, op_count)
|
self.assertEqual(3, op_count)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(1, op_count)
|
self.assertEqual(1, op_count)
|
||||||
|
@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging
|
|||||||
class SoftmaxTest(test.TestCase):
|
class SoftmaxTest(test.TestCase):
|
||||||
|
|
||||||
def _npSoftmax(self, features, dim=-1, log=False):
|
def _npSoftmax(self, features, dim=-1, log=False):
|
||||||
if dim is -1:
|
if dim == -1:
|
||||||
dim = len(features.shape) - 1
|
dim = len(features.shape) - 1
|
||||||
one_only_on_dim = list(features.shape)
|
one_only_on_dim = list(features.shape)
|
||||||
one_only_on_dim[dim] = 1
|
one_only_on_dim[dim] = 1
|
||||||
|
@ -41,7 +41,7 @@ from tensorflow.python.platform import test
|
|||||||
class XentTest(test.TestCase):
|
class XentTest(test.TestCase):
|
||||||
|
|
||||||
def _npXent(self, features, labels, dim=-1):
|
def _npXent(self, features, labels, dim=-1):
|
||||||
if dim is -1:
|
if dim == -1:
|
||||||
dim = len(features.shape) - 1
|
dim = len(features.shape) - 1
|
||||||
one_only_on_dim = list(features.shape)
|
one_only_on_dim = list(features.shape)
|
||||||
one_only_on_dim[dim] = 1
|
one_only_on_dim[dim] = 1
|
||||||
|
@ -154,7 +154,7 @@ def _get_pad_shape(params, indices):
|
|||||||
if params.shape.ndims == indices.shape.ndims:
|
if params.shape.ndims == indices.shape.ndims:
|
||||||
# When params and indices are the same rank, the shape of the pad tensor is
|
# When params and indices are the same rank, the shape of the pad tensor is
|
||||||
# almost identical to params, except the last dimension which has size = 1.
|
# almost identical to params, except the last dimension which has size = 1.
|
||||||
if params_shape.num_inner_dimensions is 0:
|
if params_shape.num_inner_dimensions == 0:
|
||||||
pad_dims = params_shape.partitioned_dim_sizes[:-1] + (
|
pad_dims = params_shape.partitioned_dim_sizes[:-1] + (
|
||||||
array_ops.ones_like(params_shape.partitioned_dim_sizes[-1]),)
|
array_ops.ones_like(params_shape.partitioned_dim_sizes[-1]),)
|
||||||
return ragged_tensor_shape.RaggedTensorDynamicShape(
|
return ragged_tensor_shape.RaggedTensorDynamicShape(
|
||||||
|
@ -67,7 +67,7 @@ def get_plugin_asset(plugin_asset_cls, graph=None):
|
|||||||
name = _PLUGIN_ASSET_PREFIX + plugin_asset_cls.plugin_name
|
name = _PLUGIN_ASSET_PREFIX + plugin_asset_cls.plugin_name
|
||||||
container = graph.get_collection(name)
|
container = graph.get_collection(name)
|
||||||
if container:
|
if container:
|
||||||
if len(container) is not 1:
|
if len(container) != 1:
|
||||||
raise ValueError("Collection for %s had %d items, expected 1" %
|
raise ValueError("Collection for %s had %d items, expected 1" %
|
||||||
(name, len(container)))
|
(name, len(container)))
|
||||||
instance = container[0]
|
instance = container[0]
|
||||||
@ -102,7 +102,7 @@ def get_all_plugin_assets(graph=None):
|
|||||||
out = []
|
out = []
|
||||||
for name in graph.get_collection(_PLUGIN_ASSET_PREFIX):
|
for name in graph.get_collection(_PLUGIN_ASSET_PREFIX):
|
||||||
collection = graph.get_collection(_PLUGIN_ASSET_PREFIX + name)
|
collection = graph.get_collection(_PLUGIN_ASSET_PREFIX + name)
|
||||||
if len(collection) is not 1:
|
if len(collection) != 1:
|
||||||
raise ValueError("Collection for %s had %d items, expected 1" %
|
raise ValueError("Collection for %s had %d items, expected 1" %
|
||||||
(name, len(collection)))
|
(name, len(collection)))
|
||||||
out.append(collection[0])
|
out.append(collection[0])
|
||||||
|
Loading…
Reference in New Issue
Block a user