diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 930e62b6809..c88c0f52eea 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -347,7 +347,7 @@ class Mean(Metric):
     Raises:
       ValueError: if the optional argument is not bool
     """
-    # Convert the boolean to tensor for tf.cond, if it is not.
+    # Convert the boolean to tensor for tf.cond, if it is not.
     if not isinstance(write_summary, ops.Tensor):
       write_summary = ops.convert_to_tensor(write_summary)
     t = self.numer / self.denom
@@ -487,6 +487,8 @@ class BinaryAccuracy(Mean):
           message="Shapes of labels and predictions are unequal")
     predictions = ops.convert_to_tensor(predictions)
     predictions = predictions > self.threshold
+    # Convert labels to bool to match predictions.
+    labels = math_ops.cast(labels, dtypes.bool)
     matches = math_ops.equal(labels, predictions)
     matches = math_ops.cast(matches, self.dtype)
     super(BinaryAccuracy, self).call(matches, weights=weights)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index c898869dd0e..7a441bd8834 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -27,6 +27,7 @@ cc_library(
         "//tensorflow/c/eager:c_api",
         "//tensorflow/c/eager:c_api_internal",
         "//tensorflow/c/eager:tape",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/python:cpp_python_util",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 71ede721edf..5f18ab27b7e 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -574,10 +574,14 @@ def _num_elements(grad):
   raise ValueError("`grad` not a Tensor or IndexedSlices.")


+def _cast_constant(value, dtype):
+  return math_ops.cast(constant_op.constant(value), dtype)
+
+
 def _fast_fill(value, shape, dtype):
   return array_ops.fill(
-      constant_op.constant(shape, dtype=dtypes.int32),
-      constant_op.constant(value, dtype=dtype))
+      _cast_constant(shape, dtype=dtypes.int32),
+      _cast_constant(value, dtype=dtype))


 def _zeros(shape, dtype):
@@ -605,7 +609,7 @@ def _ones(shape, dtype):
     return array_ops.ones(shape, dtype)

   if shape == ():  # pylint: disable=g-explicit-bool-comparison
-    return constant_op.constant(1, dtype=dtype)
+    return _cast_constant(1, dtype=dtype)
   return _fast_fill(1, shape, dtype)
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 5f44bd4fecd..bf145d92309 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -27,6 +27,8 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/python/lib/core/ndarray_tensor.h"

+#include "tensorflow/core/framework/types.h"
+
 #include "structmember.h"  // NOLINT // For PyMemberDef

 // forward declare
@@ -135,6 +137,39 @@ PyObject* PyIntFromDataType(TF_DataType l) {
 }  // namespace

 namespace tensorflow {
+// This function checks whether the desired type is "compatible" with the
+// inferred type. At a high level, compatibility means that all integral types
+// are compatible with each other, and all floating types are compatible with
+// each other.
+//
+// Type compatibility doesn't consider overflows (i.e. int64 is *always*
+// compatible with int32). This is intended to match graph behavior.
+bool IsCompatible(int desired_dtype, TF_DataType returned_dtype) {
+  tensorflow::DataType desired =
+      static_cast<tensorflow::DataType>(desired_dtype);
+  tensorflow::DataType returned =
+      static_cast<tensorflow::DataType>(returned_dtype);
+
+  if (desired == returned) return true;
+
+  if (tensorflow::DataTypeIsInteger(desired) &&
+      tensorflow::DataTypeIsInteger(returned)) {
+    return true;
+  } else if (tensorflow::DataTypeIsFloating(desired) &&
+             tensorflow::DataTypeIsFloating(returned)) {
+    return true;
+  } else if (tensorflow::DataTypeIsComplex(desired) &&
+             (tensorflow::DataTypeIsComplex(returned) ||
+              tensorflow::DataTypeIsInteger(returned) ||
+              tensorflow::DataTypeIsFloating(returned))) {
+    return true;
+  } else if (tensorflow::DataTypeIsQuantized(desired) &&
+             tensorflow::DataTypeIsInteger(returned)) {
+    return true;
+  }
+  return false;
+}
+
 // Casts data referred to by `handle` from type `src_type_enum` to type
 // `dst_type_enum`.
 TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
@@ -376,20 +411,33 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
   if (handle == nullptr) return -1;
   TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
   if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
-    handle = tensorflow::make_safe(tensorflow::EagerCast(
-        GetContext(context), handle.get(), handle_dtype,
-        static_cast<TF_DataType>(desired_dtype), self->status));
-    if (TF_GetCode(self->status) != TF_OK) {
-      PyErr_SetString(PyExc_TypeError,
-                      tensorflow::strings::StrCat(
-                          "Error while casting from DataType ", handle_dtype,
-                          " to ", desired_dtype, ". ", TF_Message(self->status))
-                          .c_str());
-      // Cleanup self->status before returning.
-      TF_SetStatus(self->status, TF_OK, "");
+    // Check type compatibility.
+    if (tensorflow::IsCompatible(desired_dtype, handle_dtype)) {
+      handle = tensorflow::make_safe(tensorflow::EagerCast(
+          GetContext(context), handle.get(), handle_dtype,
+          static_cast<TF_DataType>(desired_dtype), self->status));
+      if (TF_GetCode(self->status) != TF_OK) {
+        PyErr_SetString(
+            PyExc_TypeError,
+            tensorflow::strings::StrCat("Error while casting from DataType ",
+                                        handle_dtype, " to ", desired_dtype,
+                                        ". ", TF_Message(self->status))
+                .c_str());
+        // Cleanup self->status before returning.
+        TF_SetStatus(self->status, TF_OK, "");
+        return -1;
+      }
+      handle_dtype = TFE_TensorHandleDataType(handle.get());
+    } else {
+      tensorflow::Safe_PyObjectPtr value_str(PyObject_Str(value));
+      PyErr_SetString(
+          PyExc_TypeError,
+          tensorflow::strings::StrCat(
+              "Cannot convert value ", TFE_GetPythonString(value_str.get()),
+              " to EagerTensor with requested dtype: ", desired_dtype)
+              .c_str());
       return -1;
     }
-    handle_dtype = TFE_TensorHandleDataType(handle.get());
   }

   // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index 4eaa1ba5362..f90fd9bbb68 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -26,6 +26,7 @@ tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
 tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);

 namespace tensorflow {
+bool IsCompatible(int desired_dtype, TF_DataType returned_dtype);
 TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);

 // TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 9cee561b32c..70de5e0c03e 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include
 #include

+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/python/eager/pywrap_tfe.h"

 #include "absl/strings/str_cat.h"
@@ -2113,7 +2114,9 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
     *handle = tensorflow::make_safe(
         tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype,
                               static_cast<TF_DataType>(desired_dtype), status));
-    if (!status->status.ok()) return false;
+    if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+      return false;
+    }
     output_dtype = desired_dtype;
   }

@@ -2122,7 +2125,9 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
     // if copying to the same device.
     *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
         handle->get(), op_exec_info.ctx, op_exec_info.device_name, status));
-    if (!status->status.ok()) return false;
+    if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+      return false;
+    }
   }
   return true;
 }
@@ -2205,14 +2210,19 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
   return true;
 }

-// Supports only 2 cases at the moment:
-// i) input is an EagerTensor
+// Supports 3 cases at the moment:
+// i) input is an EagerTensor.
 // ii) input is a ResourceVariable - in this case, the is_variable param is
 // set to true.
+// iii) input is an arbitrary python list/tuple (note, this handling doesn't
+// support packing).
 //
 // NOTE: dtype_hint_getter must *always* return a PyObject that can be
 // decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
 // increfs Py_None).
+//
+// NOTE: This function sets a python error directly, and returns false.
+// TF_Status is only passed since we don't want to have to reallocate it.
 bool ConvertToTensor(
     const FastPathOpExecInfo& op_exec_info, PyObject* input,
     tensorflow::Safe_PyObjectPtr* output_handle,
@@ -2239,25 +2249,43 @@ bool ConvertToTensor(
       tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
           tensorflow::ConvertToEagerTensor(input, dtype_hint.get())));
   if (handle == nullptr) {
-    status->status = tensorflow::errors::InvalidArgument(
-        "Unable to convert value to tensor");
-    return false;
+    return MaybeRaiseExceptionFromTFStatus(status, nullptr);
   }

   int desired_dtype = -1;
   if (dtype_hint.get() != Py_None) {
     if (!ParseTypeValue("", dtype_hint.get(), status, &desired_dtype)) {
-      status->status = tensorflow::errors::InvalidArgument(
-          "Expecting a DataType value for dtype. Got ",
-          Py_TYPE(dtype_hint.get())->tp_name);
+      PyErr_SetString(PyExc_TypeError,
+                      tensorflow::strings::StrCat(
+                          "Expecting a DataType value for dtype. Got ",
+                          Py_TYPE(dtype_hint.get())->tp_name)
+                          .c_str());
+      return false;
     }
   }

-  if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
-                  &handle, status)) {
-    return false;
-  }
+  // Maybe cast to the desired type. This is intended to match python
+  // convert_to_tensor behavior.
   TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get());
+  if (desired_dtype >= 0 && desired_dtype != output_dtype) {
+    if (tensorflow::IsCompatible(desired_dtype, output_dtype)) {
+      if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
+                      &handle, status)) {
+        return false;
+      }
+      output_dtype = TFE_TensorHandleDataType(handle.get());
+    } else {
+      tensorflow::Safe_PyObjectPtr input_str(PyObject_Str(input));
+      PyErr_SetString(
+          PyExc_TypeError,
+          tensorflow::strings::StrCat(
+              "Cannot convert value ", TFE_GetPythonString(input_str.get()),
+              " to EagerTensor with requested dtype: ", desired_dtype)
+              .c_str());
+      return false;
+    }
+  }
+
   output_handle->reset(EagerTensorFromHandle(handle.release()));
   dtype_setter(output_dtype);
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index cd3d05d05e3..4fe6c44d059 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -139,8 +139,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
     self.assertAllEqual(t, 1.0)

   def testConstantDtype(self):
-    self.assertEqual(constant_op.constant(1.0, dtype=np.int64).dtype,
-                     dtypes.int64)
+    self.assertEqual(
+        constant_op.constant(1, dtype=np.int64).dtype, dtypes.int64)

   def testTensorAndNumpyMatrix(self):
     expected = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
@@ -250,6 +250,61 @@ class TFETensorTest(test_util.TensorFlowTestCase):
     with self.assertRaises(UnicodeDecodeError):
       io_ops.read_file(b"\xff")

+  @test_util.run_in_graph_and_eager_modes
+  def testConvertToTensorPreferredDtypeIsRespected(self):
+    self.assertEqual(
+        ops.convert_to_tensor(0.5, preferred_dtype=dtypes.int32).dtype,
+        dtypes.float32)
+    self.assertEqual(
+        ops.convert_to_tensor(0.5, preferred_dtype=dtypes.float64).dtype,
+        dtypes.float64)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testCompatibility(self):
+    # TODO(nareshmodi): uint32, uint64 are not correctly handled in graph mode.
+    integer_types = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
+                     dtypes.uint8, dtypes.uint16]
+
+    # Floats are not compatible with ints
+    for t in integer_types:
+      with self.assertRaises(TypeError):
+        constant_op.constant(0.5, dtype=t)
+
+    # Ints compatible with floats
+    self.assertEqual(
+        self.evaluate(constant_op.constant(5, dtype=dtypes.float16)), 5.0)
+    self.assertEqual(
+        self.evaluate(constant_op.constant(5, dtype=dtypes.float32)), 5.0)
+    self.assertEqual(
+        self.evaluate(constant_op.constant(5, dtype=dtypes.float64)), 5.0)
+
+    # Ints and floats are compatible with complex types
+    self.assertEqual(
+        constant_op.constant([[1.0]], dtype=dtypes.complex128).dtype,
+        dtypes.complex128)
+    self.assertEqual(
+        constant_op.constant([[1]], dtype=dtypes.complex128).dtype,
+        dtypes.complex128)
+
+    # Quantized types are not compatible with floats
+    quantized_types = [dtypes.qint16, dtypes.qint32, dtypes.qint8,
+                       dtypes.quint16, dtypes.quint8]
+
+    for t in quantized_types:
+      with self.assertRaises(TypeError):
+        constant_op.constant(0.5, dtype=t)
+
+    # TODO(b/118402529): quantized types are broken in eager.
+
+  @test_util.run_in_graph_and_eager_modes
+  def testCConvertToTensor(self):
+    with self.assertRaises(TypeError):
+      _ = constant_op.constant(0) < 0.5
+
+  @test_util.run_in_graph_and_eager_modes
+  def testConvertToTensorAllowsOverflow(self):
+    _ = ops.convert_to_tensor(123456789, dtype=dtypes.uint8)
+

 class TFETensorUtilTest(test_util.TensorFlowTestCase):
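
Note (illustration only, not part of the patch): taken together, these changes make eager tensor creation enforce the dtype-compatibility rules that graph-mode conversion already uses, and surface an incompatible request as a Python TypeError at construction time. The sketch below shows the user-visible behavior the new tests exercise; it assumes a TensorFlow build that includes this change, with eager execution enabled (via tf.enable_eager_execution() on a 1.x build).

import tensorflow as tf

tf.enable_eager_execution()  # 1.x builds: run ops eagerly

# Ints are compatible with floats: the value is cast to the requested dtype.
x = tf.constant(5, dtype=tf.float32)
assert x.numpy() == 5.0

# Floats are not compatible with ints: the request is rejected up front
# with a TypeError, matching graph-mode constant creation.
try:
  tf.constant(0.5, dtype=tf.int32)
except TypeError as e:
  print("rejected:", e)

# The compatibility check deliberately ignores overflow (as the IsCompatible
# comment states), so a value that does not fit the requested integral type
# is still accepted.
_ = tf.convert_to_tensor(123456789, dtype=tf.uint8)

# preferred_dtype remains a hint: an incompatible hint is ignored, not an error.
t = tf.convert_to_tensor(0.5, preferred_dtype=tf.int32)
assert t.dtype == tf.float32

Because compatibility is type-based only, range validation stays the caller's responsibility; the patch changes which conversions are rejected, not how in-range casts behave.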