Enable a few test targets for tfrt. Disable quantization test since we don't plan to have quantization support in the initial launch.

PiperOrigin-RevId: 335143411
Change-Id: I606bacf12bd9b349da304cd97a8acc081dc758f0
This commit is contained in:
Xiao Yu 2020-10-02 19:22:09 -07:00 committed by TensorFlower Gardener
parent 3b0672c24b
commit b1109ff545
6 changed files with 11 additions and 0 deletions

View File

@ -1810,6 +1810,7 @@ cuda_py_test(
name = "bitcast_op_test",
size = "small",
srcs = ["bitcast_op_test.py"],
tfrt_enabled = True,
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@ -1841,6 +1842,7 @@ cuda_py_test(
name = "constant_op_test",
size = "small",
srcs = ["constant_op_test.py"],
tfrt_enabled = True,
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@ -2096,6 +2098,7 @@ cuda_py_test(
name = "dynamic_stitch_op_test",
size = "small",
srcs = ["dynamic_stitch_op_test.py"],
tfrt_enabled = True,
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:data_flow_grad",
@ -2683,6 +2686,7 @@ cuda_py_test(
"no_windows",
"no_windows_gpu",
],
tfrt_enabled = True,
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",

View File

@ -82,6 +82,7 @@ class BitcastTest(test.TestCase):
datatype = dtypes.int8
array_ops.bitcast(x, datatype, None)
@test_util.disable_tfrt("b/169901260")
def testQuantizedType(self):
shape = [3, 4]
x = np.zeros(shape, np.uint16)

View File

@ -456,6 +456,7 @@ class ZerosTest(test.TestCase):
self.assertFalse(np.any(z_value))
self.assertEqual((2, 3), z_value.shape)
@test_util.disable_tfrt("b/169901260")
def testQint8Dtype(self):
dtype = dtypes_lib.qint8
z = array_ops.zeros([2, 3], dtype=dtype)
@ -466,6 +467,7 @@ class ZerosTest(test.TestCase):
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
self.assertFalse(np.any(z_value))
@test_util.disable_tfrt("b/169901260")
def testQint16Dtype(self):
dtype = dtypes_lib.qint16
z = array_ops.zeros([2, 3], dtype=dtype)
@ -650,6 +652,7 @@ class OnesTest(test.TestCase):
self.assertEqual([2, 3], z.get_shape())
self.assertAllEqual(z, np.ones([2, 3]))
@test_util.disable_tfrt("b/169901260")
def testQintDtype(self):
@def_function.function(autograph=False)

View File

@ -991,6 +991,7 @@ class ComparisonOpTest(test.TestCase):
[[True, True, True, True, True], [False, False, False, False, False]],
values)
@test_util.disable_tfrt("b/169901260")
def testEqualQuantizeDType(self):
dtypes = [
dtypes_lib.qint8,

View File

@ -62,6 +62,7 @@ class DynamicStitchTestBase(object):
# length.
self.assertEqual([None], stitched_t.get_shape().as_list())
@test_util.disable_tfrt("b/169901260")
def testSimpleOneDimensional(self):
# Test various datatypes in the simple case to ensure that the op was
# registered under those types.

View File

@ -309,6 +309,7 @@ class SpaceToDepthTest(test.TestCase):
actual_vals, expected_vals = self.evaluate([actual, expected])
self.assertTrue(np.array_equal(actual_vals, expected_vals))
@test_util.disable_tfrt("b/169901260")
def testAgainstTranspose(self):
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.float32, False)
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", dtypes.float32, False)