diff --git a/tensorflow/python/distribute/collective_util.py b/tensorflow/python/distribute/collective_util.py index 0d4554480b5..4fef896a326 100644 --- a/tensorflow/python/distribute/collective_util.py +++ b/tensorflow/python/distribute/collective_util.py @@ -81,7 +81,11 @@ class _OptionsExported(object): """ def __new__(cls, *args, **kwargs): - return Options.__new__(Options, *args, **kwargs) + # We expose a dummy class so that we can separate internal and public APIs. + # Note that __init__ won't be called on the returned object if it's a + # different class [1]. + # [1] https://docs.python.org/3/reference/datamodel.html#object.__new__ + return Options(*args, **kwargs) def __init__(self, bytes_per_pack=0, diff --git a/tensorflow/python/distribute/collective_util_test.py b/tensorflow/python/distribute/collective_util_test.py index e75d520979b..984442901fb 100644 --- a/tensorflow/python/distribute/collective_util_test.py +++ b/tensorflow/python/distribute/collective_util_test.py @@ -25,8 +25,11 @@ from tensorflow.python.eager import test class OptionsTest(test.TestCase): def testCreateOptionsViaExportedAPI(self): - options = collective_util._OptionsExported() + options = collective_util._OptionsExported(bytes_per_pack=1) self.assertIsInstance(options, collective_util.Options) + self.assertEqual(options.bytes_per_pack, 1) + with self.assertRaises(ValueError): + collective_util._OptionsExported(bytes_per_pack=-1) def testCreateOptionsViaHints(self): with self.assertLogs() as cm: