Import package xla_test instead of class XLATestCase.
PiperOrigin-RevId: 202572322
This commit is contained in:
parent
c81830af5d
commit
6dc9977e1d
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -28,7 +28,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adagrad
|
||||
|
||||
|
||||
class AdagradOptimizerTest(XLATestCase):
|
||||
class AdagradOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
for dtype in self.float_types:
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -48,7 +48,7 @@ def adam_update_numpy(param,
|
||||
return param_t, m_t, v_t
|
||||
|
||||
|
||||
class AdamOptimizerTest(XLATestCase):
|
||||
class AdamOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
for dtype in self.float_types:
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class BinaryOpsTest(XLATestCase):
|
||||
class BinaryOpsTest(xla_test.XLATestCase):
|
||||
"""Test cases for binary operators."""
|
||||
|
||||
def _testBinary(self, op, a, b, expected, equality_test=None):
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -26,7 +26,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class BucketizationOpTest(XLATestCase):
|
||||
class BucketizationOpTest(xla_test.XLATestCase):
|
||||
|
||||
def testInt(self):
|
||||
with self.test_session() as sess:
|
||||
|
@ -22,7 +22,7 @@ import collections
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
# TODO(srvasude): Merge this with
|
||||
# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py.
|
||||
class CategoricalTest(XLATestCase):
|
||||
class CategoricalTest(xla_test.XLATestCase):
|
||||
"""Test cases for random-number generating operators."""
|
||||
|
||||
def output_dtypes(self):
|
||||
|
@ -23,7 +23,7 @@ import unittest
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -32,7 +32,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class CholeskyOpTest(XLATestCase):
|
||||
class CholeskyOpTest(xla_test.XLATestCase):
|
||||
|
||||
# Cholesky defined for float64, float32, complex64, complex128
|
||||
# (https://www.tensorflow.org/api_docs/python/tf/cholesky)
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
|
||||
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
|
||||
|
||||
class ClusteringTest(XLATestCase):
|
||||
class ClusteringTest(xla_test.XLATestCase):
|
||||
|
||||
def testAdd(self):
|
||||
val1 = np.array([4, 3, 2, 1], dtype=np.float32)
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class ConcatTest(XLATestCase):
|
||||
class ConcatTest(xla_test.XLATestCase):
|
||||
|
||||
def testHStack(self):
|
||||
with self.test_session():
|
||||
@ -292,7 +292,7 @@ class ConcatTest(XLATestCase):
|
||||
array_ops.concat([scalar, scalar, scalar], dim)
|
||||
|
||||
|
||||
class ConcatOffsetTest(XLATestCase):
|
||||
class ConcatOffsetTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_session() as sess:
|
||||
@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase):
|
||||
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
|
||||
|
||||
|
||||
class PackTest(XLATestCase):
|
||||
class PackTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_session() as sess:
|
||||
|
@ -26,7 +26,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import test_utils
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
@ -42,7 +42,7 @@ DATA_FORMATS = (
|
||||
)
|
||||
|
||||
|
||||
class Conv2DTest(XLATestCase, parameterized.TestCase):
|
||||
class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def _VerifyValues(self,
|
||||
input_sizes=None,
|
||||
@ -236,7 +236,7 @@ class Conv2DTest(XLATestCase, parameterized.TestCase):
|
||||
expected=np.reshape([108, 128], [1, 1, 1, 2]))
|
||||
|
||||
|
||||
class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase):
|
||||
class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def _VerifyValues(self,
|
||||
input_sizes=None,
|
||||
@ -534,7 +534,7 @@ class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase):
|
||||
expected=[5, 0, 11, 0, 0, 0, 17, 0, 23])
|
||||
|
||||
|
||||
class Conv2DBackpropFilterTest(XLATestCase, parameterized.TestCase):
|
||||
class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def _VerifyValues(self,
|
||||
input_sizes=None,
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
# Test cloned from
|
||||
# tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
|
||||
class Conv3DBackpropFilterV2GradTest(XLATestCase):
|
||||
class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):
|
||||
|
||||
def testGradient(self):
|
||||
with self.test_session(), self.test_scope():
|
||||
@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase):
|
||||
|
||||
|
||||
# Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py
|
||||
class Conv3DTransposeTest(XLATestCase):
|
||||
class Conv3DTransposeTest(xla_test.XLATestCase):
|
||||
|
||||
def testConv3DTransposeSingleStride(self):
|
||||
with self.test_session(), self.test_scope():
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -114,7 +114,7 @@ def CheckGradConfigsToTest():
|
||||
yield i, f, o, s, p
|
||||
|
||||
|
||||
class DepthwiseConv2DTest(XLATestCase):
|
||||
class DepthwiseConv2DTest(xla_test.XLATestCase):
|
||||
|
||||
# This is testing that depthwise_conv2d and depthwise_conv2d_native
|
||||
# produce the same results. It also tests that NCHW and NWHC
|
||||
|
@ -20,14 +20,14 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.compiler.tf2xla.python import xla
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class DynamicUpdateSliceOpsTest(XLATestCase):
|
||||
class DynamicUpdateSliceOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def _assertOpOutputMatchesExpected(self, op, args, expected):
|
||||
with self.test_session() as session:
|
||||
|
@ -20,14 +20,14 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class DynamicStitchTest(XLATestCase):
|
||||
class DynamicStitchTest(xla_test.XLATestCase):
|
||||
|
||||
def _AssertDynamicStitchResultIs(self, indices, data, expected):
|
||||
with self.test_session() as session:
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
@ -40,7 +40,7 @@ from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.training import adam
|
||||
|
||||
|
||||
class EagerTest(XLATestCase):
|
||||
class EagerTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_scope():
|
||||
@ -286,7 +286,7 @@ class EagerTest(XLATestCase):
|
||||
[2.0, 2.0]], embedding_matrix.numpy())
|
||||
|
||||
|
||||
class EagerFunctionTest(XLATestCase):
|
||||
class EagerFunctionTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_scope():
|
||||
@ -419,7 +419,7 @@ class EagerFunctionTest(XLATestCase):
|
||||
self.assertAllEqual((2, 3, 4), dz.shape.as_list())
|
||||
|
||||
|
||||
class ExcessivePaddingTest(XLATestCase):
|
||||
class ExcessivePaddingTest(xla_test.XLATestCase):
|
||||
"""Test that eager execution works with TPU flattened tensors.
|
||||
|
||||
Tensors that would normally be excessively padded when written
|
||||
|
@ -20,13 +20,13 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ExtractImagePatches(XLATestCase):
|
||||
class ExtractImagePatches(xla_test.XLATestCase):
|
||||
"""Functional tests for ExtractImagePatches op."""
|
||||
|
||||
def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):
|
||||
|
@ -17,14 +17,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxArgsTest(XLATestCase):
|
||||
class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase):
|
||||
"""Test cases for FakeQuantWithMinMaxArgs operation."""
|
||||
|
||||
# 8 bits, wide range.
|
||||
@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase):
|
||||
result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
|
||||
class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase):
|
||||
"""Test cases for FakeQuantWithMinMaxArgsGradient operation."""
|
||||
|
||||
# 8 bits, wide range.
|
||||
@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
|
||||
bfloat16_rtol=0.03)
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxVarsTest(XLATestCase):
|
||||
class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase):
|
||||
"""Test cases for FakeQuantWithMinMaxVars operation."""
|
||||
|
||||
# 8 bits, wide range.
|
||||
@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase):
|
||||
result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
|
||||
|
||||
|
||||
class FakeQuantWithMinMaxVarsGradientTest(XLATestCase):
|
||||
class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase):
|
||||
"""Test cases for FakeQuantWithMinMaxVarsGradient operation."""
|
||||
|
||||
# 8 bits, wide range.
|
||||
|
@ -23,7 +23,7 @@ import itertools
|
||||
import numpy as np
|
||||
import scipy.signal as sps
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.contrib.signal.python.ops import spectral_ops as signal
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -58,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2))
|
||||
INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2))
|
||||
|
||||
|
||||
class FFTTest(XLATestCase):
|
||||
class FFTTest(xla_test.XLATestCase):
|
||||
|
||||
def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected,
|
||||
tf_method):
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
class FtrlOptimizerTest(XLATestCase):
|
||||
class FtrlOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def initVariableAndGradient(self, dtype):
|
||||
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class FunctionTest(XLATestCase):
|
||||
class FunctionTest(xla_test.XLATestCase):
|
||||
|
||||
def testFunction(self):
|
||||
"""Executes a simple TensorFlow function."""
|
||||
|
@ -22,7 +22,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import test_utils
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
@ -30,7 +30,7 @@ from tensorflow.python.ops import nn
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FusedBatchNormTest(XLATestCase, parameterized.TestCase):
|
||||
class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def _reference_training(self, x, scale, offset, epsilon, data_format):
|
||||
if data_format != "NHWC":
|
||||
|
@ -20,13 +20,13 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class GatherNdTest(XLATestCase):
|
||||
class GatherNdTest(xla_test.XLATestCase):
|
||||
|
||||
def _runGather(self, params, indices):
|
||||
with self.test_session():
|
||||
|
@ -25,7 +25,7 @@ import numpy as np
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -41,7 +41,7 @@ def GenerateNumpyRandomRGB(shape):
|
||||
return np.random.randint(0, 256, shape) / 256.
|
||||
|
||||
|
||||
class RGBToHSVTest(XLATestCase):
|
||||
class RGBToHSVTest(xla_test.XLATestCase):
|
||||
|
||||
def testBatch(self):
|
||||
# Build an arbitrary RGB image
|
||||
@ -104,7 +104,7 @@ class RGBToHSVTest(XLATestCase):
|
||||
self.assertAllCloseAccordingToType(hsv_tf, hsv_np)
|
||||
|
||||
|
||||
class AdjustContrastTest(XLATestCase):
|
||||
class AdjustContrastTest(xla_test.XLATestCase):
|
||||
|
||||
def _testContrast(self, x_np, y_np, contrast_factor):
|
||||
with self.test_session():
|
||||
@ -168,7 +168,7 @@ class AdjustContrastTest(XLATestCase):
|
||||
self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
class AdjustHueTest(XLATestCase):
|
||||
class AdjustHueTest(xla_test.XLATestCase):
|
||||
|
||||
def testAdjustNegativeHue(self):
|
||||
x_shape = [2, 2, 3]
|
||||
@ -303,7 +303,7 @@ class AdjustHueTest(XLATestCase):
|
||||
self._adjustHueTf(x_np, delta_h)
|
||||
|
||||
|
||||
class AdjustSaturationTest(XLATestCase):
|
||||
class AdjustSaturationTest(xla_test.XLATestCase):
|
||||
|
||||
def _adjust_saturation(self, image, saturation_factor):
|
||||
image = ops.convert_to_tensor(image, name="image")
|
||||
@ -403,7 +403,7 @@ class AdjustSaturationTest(XLATestCase):
|
||||
self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5)
|
||||
|
||||
|
||||
class ResizeBilinearTest(XLATestCase):
|
||||
class ResizeBilinearTest(xla_test.XLATestCase):
|
||||
|
||||
def _assertForwardOpMatchesExpected(self,
|
||||
image_np,
|
||||
|
@ -22,7 +22,7 @@ import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
|
||||
# Local response normalization tests. The forward tests are copied from
|
||||
# tensorflow/python/kernel_tests/lrn_op_test.py
|
||||
class LRNTest(XLATestCase):
|
||||
class LRNTest(xla_test.XLATestCase):
|
||||
|
||||
def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0,
|
||||
beta=0.5):
|
||||
|
@ -19,14 +19,14 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MatrixBandPartTest(XLATestCase):
|
||||
class MatrixBandPartTest(xla_test.XLATestCase):
|
||||
|
||||
def _testMatrixBandPart(self, dtype, shape):
|
||||
with self.test_session():
|
||||
|
@ -22,7 +22,7 @@ import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -35,7 +35,7 @@ def MakePlaceholder(x):
|
||||
return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape)
|
||||
|
||||
|
||||
class MatrixTriangularSolveOpTest(XLATestCase):
|
||||
class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
|
||||
|
||||
# MatrixTriangularSolve defined for float64, float32, complex64, complex128
|
||||
# (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve)
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -30,7 +30,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import momentum as momentum_lib
|
||||
|
||||
|
||||
class MomentumOptimizerTest(XLATestCase):
|
||||
class MomentumOptimizerTest(xla_test.XLATestCase):
|
||||
|
||||
def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
|
||||
var += accum * lr * momentum
|
||||
|
@ -22,14 +22,14 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class NAryOpsTest(XLATestCase):
|
||||
class NAryOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def _testNAry(self, op, args, expected, equality_fn=None):
|
||||
with self.test_session() as session:
|
||||
|
@ -20,13 +20,13 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class NullaryOpsTest(XLATestCase):
|
||||
class NullaryOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def _testNullary(self, op, expected):
|
||||
with self.test_session() as session:
|
||||
|
@ -18,14 +18,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class PlaceholderTest(XLATestCase):
|
||||
class PlaceholderTest(xla_test.XLATestCase):
|
||||
|
||||
def test_placeholder_with_default_default(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding):
|
||||
padding=padding)
|
||||
|
||||
|
||||
class Pooling3DTest(XLATestCase):
|
||||
class Pooling3DTest(xla_test.XLATestCase):
|
||||
|
||||
def _VerifyValues(self, pool_func, input_sizes, window, strides, padding,
|
||||
expected):
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -69,7 +69,7 @@ def GetTestConfigs():
|
||||
return test_configs
|
||||
|
||||
|
||||
class PoolingTest(XLATestCase):
|
||||
class PoolingTest(xla_test.XLATestCase):
|
||||
|
||||
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, expected):
|
||||
@ -288,7 +288,7 @@ class PoolingTest(XLATestCase):
|
||||
expected=expected_output)
|
||||
|
||||
|
||||
class PoolGradTest(XLATestCase):
|
||||
class PoolGradTest(xla_test.XLATestCase):
|
||||
|
||||
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
|
||||
|
@ -22,7 +22,7 @@ import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -31,7 +31,7 @@ from tensorflow.python.ops.distributions import special_math
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class RandomOpsTest(XLATestCase):
|
||||
class RandomOpsTest(xla_test.XLATestCase):
|
||||
"""Test cases for random-number generating operators."""
|
||||
|
||||
def _random_types(self):
|
||||
|
@ -22,7 +22,7 @@ import functools
|
||||
import itertools
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class ReduceOpsTest(XLATestCase):
|
||||
class ReduceOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def _testReduction(self,
|
||||
tf_reduce_fn,
|
||||
@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase):
|
||||
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
|
||||
|
||||
|
||||
class ReduceOpPrecisionTest(XLATestCase):
|
||||
class ReduceOpPrecisionTest(xla_test.XLATestCase):
|
||||
|
||||
def _testReduceSum(self,
|
||||
expected_result,
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.compiler.tf2xla.python import xla
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class ReduceWindowTest(XLATestCase):
|
||||
class ReduceWindowTest(xla_test.XLATestCase):
|
||||
"""Test cases for xla.reduce_window."""
|
||||
|
||||
def _reduce_window(self, operand, init, reducer, **kwargs):
|
||||
|
@ -21,14 +21,14 @@ from __future__ import print_function
|
||||
import itertools
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class ReverseOpsTest(XLATestCase):
|
||||
class ReverseOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def testReverseOneDim(self):
|
||||
shape = (7, 5, 9, 11)
|
||||
|
@ -20,13 +20,13 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ReverseSequenceTest(XLATestCase):
|
||||
class ReverseSequenceTest(xla_test.XLATestCase):
|
||||
|
||||
def _testReverseSequence(self,
|
||||
x,
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -28,7 +28,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import rmsprop
|
||||
|
||||
|
||||
class RmspropTest(XLATestCase):
|
||||
class RmspropTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
for dtype in self.float_types:
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse):
|
||||
return x
|
||||
|
||||
|
||||
class CumsumTest(XLATestCase):
|
||||
class CumsumTest(xla_test.XLATestCase):
|
||||
|
||||
valid_dtypes = [np.float32]
|
||||
|
||||
@ -147,7 +147,7 @@ class CumsumTest(XLATestCase):
|
||||
math_ops.cumsum(input_tensor, [0]).eval()
|
||||
|
||||
|
||||
class CumprodTest(XLATestCase):
|
||||
class CumprodTest(xla_test.XLATestCase):
|
||||
|
||||
valid_dtypes = [np.float32]
|
||||
|
||||
|
@ -22,7 +22,7 @@ import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape):
|
||||
return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)
|
||||
|
||||
|
||||
class ScatterNdTest(XLATestCase):
|
||||
class ScatterNdTest(xla_test.XLATestCase):
|
||||
|
||||
def _VariableRankTest(self,
|
||||
np_scatter,
|
||||
|
@ -18,14 +18,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class SliceTest(XLATestCase):
|
||||
class SliceTest(xla_test.XLATestCase):
|
||||
|
||||
def test1D(self):
|
||||
for dtype in self.numeric_types:
|
||||
@ -110,7 +110,7 @@ class SliceTest(XLATestCase):
|
||||
self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result)
|
||||
|
||||
|
||||
class StridedSliceTest(XLATestCase):
|
||||
class StridedSliceTest(xla_test.XLATestCase):
|
||||
|
||||
def test1D(self):
|
||||
for dtype in self.numeric_types:
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
@ -68,7 +68,7 @@ def space_to_batch_direct(input_array, block_shape, paddings):
|
||||
return permuted_reshaped_padded.reshape(output_shape)
|
||||
|
||||
|
||||
class SpaceToBatchTest(XLATestCase):
|
||||
class SpaceToBatchTest(xla_test.XLATestCase):
|
||||
"""Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
|
||||
|
||||
def _testPad(self, inputs, paddings, block_size, outputs):
|
||||
@ -149,7 +149,7 @@ class SpaceToBatchTest(XLATestCase):
|
||||
self._testOne(x_np, block_size, x_out)
|
||||
|
||||
|
||||
class SpaceToBatchNDTest(XLATestCase):
|
||||
class SpaceToBatchNDTest(xla_test.XLATestCase):
|
||||
"""Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops."""
|
||||
|
||||
def _testPad(self, inputs, block_shape, paddings, outputs):
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_data_flow_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class StackOpTest(XLATestCase):
|
||||
class StackOpTest(xla_test.XLATestCase):
|
||||
|
||||
def testStackPushPop(self):
|
||||
with self.test_session(), self.test_scope():
|
||||
|
@ -22,7 +22,7 @@ import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.contrib import stateless
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -30,7 +30,7 @@ from tensorflow.python.ops.distributions import special_math
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class StatelessRandomOpsTest(XLATestCase):
|
||||
class StatelessRandomOpsTest(xla_test.XLATestCase):
|
||||
"""Test cases for stateless random-number generator operators."""
|
||||
|
||||
def _random_types(self):
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class TernaryOpsTest(XLATestCase):
|
||||
class TernaryOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def _testTernary(self, op, a, b, c, expected):
|
||||
with self.test_session() as session:
|
||||
|
@ -23,7 +23,7 @@ import unittest
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import bitwise_ops
|
||||
@ -44,7 +44,7 @@ def nhwc_to_format(x, data_format):
|
||||
raise ValueError("Unknown format {}".format(data_format))
|
||||
|
||||
|
||||
class UnaryOpsTest(XLATestCase):
|
||||
class UnaryOpsTest(xla_test.XLATestCase):
|
||||
"""Test cases for unary operators."""
|
||||
|
||||
def _assertOpOutputMatchesExpected(self,
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -37,7 +37,7 @@ from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
|
||||
|
||||
|
||||
class VariableOpsTest(XLATestCase):
|
||||
class VariableOpsTest(xla_test.XLATestCase):
|
||||
"""Test cases for resource variable operators."""
|
||||
|
||||
def testOneWriteOneOutput(self):
|
||||
@ -435,7 +435,7 @@ class StridedSliceAssignChecker(object):
|
||||
self.test.assertAllEqual(val, valnp)
|
||||
|
||||
|
||||
class SliceAssignTest(XLATestCase):
|
||||
class SliceAssignTest(xla_test.XLATestCase):
|
||||
|
||||
def testSliceAssign(self):
|
||||
for dtype in self.numeric_types:
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.compiler.tf2xla.python import xla
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class WhileTest(XLATestCase):
|
||||
class WhileTest(xla_test.XLATestCase):
|
||||
|
||||
def testSingletonLoopHandrolled(self):
|
||||
# Define a function for the loop body
|
||||
|
@ -20,14 +20,14 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_control_flow_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class XlaDeviceTest(XLATestCase):
|
||||
class XlaDeviceTest(xla_test.XLATestCase):
|
||||
|
||||
def testCopies(self):
|
||||
"""Tests that copies onto and off XLA devices work."""
|
||||
|
Loading…
Reference in New Issue
Block a user