Automated rollback of commit 8f6c5d3252

PiperOrigin-RevId: 224364244
This commit is contained in:
Akshay Modi 2018-12-06 10:36:56 -08:00 committed by TensorFlower Gardener
parent e576acf5db
commit 05b2aaacf1
2 changed files with 20 additions and 0 deletions

View File

@ -220,6 +220,14 @@ TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype) {
return nullptr; return nullptr;
} }
} }
tensorflow::Safe_PyObjectPtr value_decrefer;
if (PyArray_IsScalar(value, Generic)) {
// Convert numpy scalars to numpy arrays.
value = PyArray_FromScalar(value, nullptr);
// The returned value needs to be DECREF'd, but the original value was
// created in python code, and doesn't need to be DECREF'd.
value_decrefer.reset(value);
}
if (PyArray_Check(value)) { if (PyArray_Check(value)) {
int desired_np_dtype = -1; int desired_np_dtype = -1;
if (desired_dtype >= 0) { if (desired_dtype >= 0) {

View File

@ -95,6 +95,18 @@ class TFETensorTest(test_util.TensorFlowTestCase):
t = _create_tensor(values) t = _create_tensor(values)
self.assertAllEqual(values, t) self.assertAllEqual(values, t)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testNumpyDtypeSurvivesThroughTensorConversion(self):
scalar_creators = [np.int32, np.int64, np.float32, np.float64]
conversion_functions = [ops.convert_to_tensor, constant_op.constant]
for scalar_creator in scalar_creators:
for conversion_function in conversion_functions:
np_val = scalar_creator(3)
tensor_val = conversion_function(np_val)
self.assertEqual(tensor_val.numpy().dtype, np_val.dtype)
self.assertEqual(tensor_val.numpy(), np_val)
def testNumpyValueWithCast(self): def testNumpyValueWithCast(self):
values = np.array([3.0], dtype=np.float32) values = np.array([3.0], dtype=np.float32)
t = _create_tensor(values, dtype=dtypes.float64) t = _create_tensor(values, dtype=dtypes.float64)