parent
e576acf5db
commit
05b2aaacf1
@ -220,6 +220,14 @@ TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
tensorflow::Safe_PyObjectPtr value_decrefer;
|
||||||
|
if (PyArray_IsScalar(value, Generic)) {
|
||||||
|
// Convert numpy scalars to numpy arrays.
|
||||||
|
value = PyArray_FromScalar(value, nullptr);
|
||||||
|
// The returned value needs to be DECREF'd, but the original value was
|
||||||
|
// created in python code, and doesn't need to be DECREF'd.
|
||||||
|
value_decrefer.reset(value);
|
||||||
|
}
|
||||||
if (PyArray_Check(value)) {
|
if (PyArray_Check(value)) {
|
||||||
int desired_np_dtype = -1;
|
int desired_np_dtype = -1;
|
||||||
if (desired_dtype >= 0) {
|
if (desired_dtype >= 0) {
|
||||||
|
@ -95,6 +95,18 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
t = _create_tensor(values)
|
t = _create_tensor(values)
|
||||||
self.assertAllEqual(values, t)
|
self.assertAllEqual(values, t)
|
||||||
|
|
||||||
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
|
def testNumpyDtypeSurvivesThroughTensorConversion(self):
|
||||||
|
scalar_creators = [np.int32, np.int64, np.float32, np.float64]
|
||||||
|
conversion_functions = [ops.convert_to_tensor, constant_op.constant]
|
||||||
|
|
||||||
|
for scalar_creator in scalar_creators:
|
||||||
|
for conversion_function in conversion_functions:
|
||||||
|
np_val = scalar_creator(3)
|
||||||
|
tensor_val = conversion_function(np_val)
|
||||||
|
self.assertEqual(tensor_val.numpy().dtype, np_val.dtype)
|
||||||
|
self.assertEqual(tensor_val.numpy(), np_val)
|
||||||
|
|
||||||
def testNumpyValueWithCast(self):
|
def testNumpyValueWithCast(self):
|
||||||
values = np.array([3.0], dtype=np.float32)
|
values = np.array([3.0], dtype=np.float32)
|
||||||
t = _create_tensor(values, dtype=dtypes.float64)
|
t = _create_tensor(values, dtype=dtypes.float64)
|
||||||
|
Loading…
Reference in New Issue
Block a user