Enable a few test targets for tfrt. Disable quantization test since we don't plan to have quantization support in the initial launch.
PiperOrigin-RevId: 335143411 Change-Id: I606bacf12bd9b349da304cd97a8acc081dc758f0
This commit is contained in:
parent
3b0672c24b
commit
b1109ff545
@ -1810,6 +1810,7 @@ cuda_py_test(
|
||||
name = "bitcast_op_test",
|
||||
size = "small",
|
||||
srcs = ["bitcast_op_test.py"],
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -1841,6 +1842,7 @@ cuda_py_test(
|
||||
name = "constant_op_test",
|
||||
size = "small",
|
||||
srcs = ["constant_op_test.py"],
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -2096,6 +2098,7 @@ cuda_py_test(
|
||||
name = "dynamic_stitch_op_test",
|
||||
size = "small",
|
||||
srcs = ["dynamic_stitch_op_test.py"],
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:data_flow_grad",
|
||||
@ -2683,6 +2686,7 @@ cuda_py_test(
|
||||
"no_windows",
|
||||
"no_windows_gpu",
|
||||
],
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -82,6 +82,7 @@ class BitcastTest(test.TestCase):
|
||||
datatype = dtypes.int8
|
||||
array_ops.bitcast(x, datatype, None)
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testQuantizedType(self):
|
||||
shape = [3, 4]
|
||||
x = np.zeros(shape, np.uint16)
|
||||
|
@ -456,6 +456,7 @@ class ZerosTest(test.TestCase):
|
||||
self.assertFalse(np.any(z_value))
|
||||
self.assertEqual((2, 3), z_value.shape)
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testQint8Dtype(self):
|
||||
dtype = dtypes_lib.qint8
|
||||
z = array_ops.zeros([2, 3], dtype=dtype)
|
||||
@ -466,6 +467,7 @@ class ZerosTest(test.TestCase):
|
||||
z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32))
|
||||
self.assertFalse(np.any(z_value))
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testQint16Dtype(self):
|
||||
dtype = dtypes_lib.qint16
|
||||
z = array_ops.zeros([2, 3], dtype=dtype)
|
||||
@ -650,6 +652,7 @@ class OnesTest(test.TestCase):
|
||||
self.assertEqual([2, 3], z.get_shape())
|
||||
self.assertAllEqual(z, np.ones([2, 3]))
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testQintDtype(self):
|
||||
|
||||
@def_function.function(autograph=False)
|
||||
|
@ -991,6 +991,7 @@ class ComparisonOpTest(test.TestCase):
|
||||
[[True, True, True, True, True], [False, False, False, False, False]],
|
||||
values)
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testEqualQuantizeDType(self):
|
||||
dtypes = [
|
||||
dtypes_lib.qint8,
|
||||
|
@ -62,6 +62,7 @@ class DynamicStitchTestBase(object):
|
||||
# length.
|
||||
self.assertEqual([None], stitched_t.get_shape().as_list())
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testSimpleOneDimensional(self):
|
||||
# Test various datatypes in the simple case to ensure that the op was
|
||||
# registered under those types.
|
||||
|
@ -309,6 +309,7 @@ class SpaceToDepthTest(test.TestCase):
|
||||
actual_vals, expected_vals = self.evaluate([actual, expected])
|
||||
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
||||
|
||||
@test_util.disable_tfrt("b/169901260")
|
||||
def testAgainstTranspose(self):
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.float32, False)
|
||||
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", dtypes.float32, False)
|
||||
|
Loading…
Reference in New Issue
Block a user