Merge pull request #30508 from I-Hong:test_deflt_val

PiperOrigin-RevId: 257229251
This commit is contained in:
TensorFlower Gardener 2019-07-09 11:22:40 -07:00
commit 348fde8095

View File

@ -265,11 +265,13 @@ class NumericColumnTest(test.TestCase):
self.assertEqual(((3., 2.),), a.default_value)
def test_shape_and_default_value_compatibility(self):
fc.numeric_column('aaa', shape=[2], default_value=[1, 2.])
a = fc.numeric_column('aaa', shape=[2], default_value=[1, 2.])
self.assertEqual((1, 2.), a.default_value)
with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
fc.numeric_column('aaa', shape=[2], default_value=[1, 2, 3.])
fc.numeric_column(
'aaa', shape=[3, 2], default_value=[[2, 3], [1, 2], [2, 3.]])
a = fc.numeric_column(
'aaa', shape=[3, 2], default_value=[[2, 3], [1, 2], [2, 3.]])
self.assertEqual(((2, 3), (1, 2), (2, 3.)), a.default_value)
with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
fc.numeric_column(
'aaa', shape=[3, 1], default_value=[[2, 3], [1, 2], [2, 3.]])
@ -858,8 +860,11 @@ class HashedCategoricalColumnTest(test.TestCase):
fc.categorical_column_with_hash_bucket('aaa', 0)
def test_dtype_should_be_string_or_integer(self):
fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.string)
fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
a = fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.string)
b = fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
self.assertEqual(dtypes.string, a.dtype)
self.assertEqual(dtypes.int32, b.dtype)
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)