diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 19c12b2c8d4..f5898c55e61 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -621,6 +621,7 @@ tf_py_test(
        "//tensorflow/python:math_ops",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:random_ops",
+        "//tensorflow/python:test_ops",
    ],
 )

diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index a7e6aad7b2c..f5f336c2323 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -144,6 +144,40 @@ AttrToInputsMap* GetAttrToInputsMap(const tensorflow::OpDef& op_def) {
   return retval;
 }

+tensorflow::mutex all_attr_to_defaults_maps_lock(
+    tensorflow::LINKER_INITIALIZED);
+tensorflow::gtl::FlatMap<
+    string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
+GetAllAttrToDefaultsMaps() {
+  static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
+      string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
+  return all_attr_to_defaults_maps;
+}
+
+tensorflow::gtl::FlatMap<string, tensorflow::DataType>* GetAttrToDefaultsMap(
+    const tensorflow::OpDef& op_def) {
+  tensorflow::mutex_lock l(all_attr_to_defaults_maps_lock);
+  auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
+
+  auto* output =
+      tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
+  if (output != nullptr) {
+    return output;
+  }
+
+  auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
+
+  for (const auto& attr : op_def.attr()) {
+    if (attr.type() == "type" && attr.has_default_value()) {
+      new_map->insert({attr.name(), attr.default_value().type()});
+    }
+  }
+
+  (*all_attr_to_defaults_maps)[op_def.name()] = new_map;
+
+  return new_map;
+}
+
 struct FastPathOpExecInfo {
   TFE_Context* ctx;
   const char* device_name;
@@ -164,6 +198,7 @@ struct FastPathOpExecInfo {
   // DTypes can come from another input that has the same attr. So build that
   // map.
   const AttrToInputsMap* attr_to_inputs_map;
+  const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
   tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
 };

@@ -969,9 +1004,7 @@ const char* TFE_GetPythonString(PyObject* o) {
 #endif
 }

-int64_t get_uid() {
-  return _uid++;
-}
+int64_t get_uid() { return _uid++; }

 PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }

@@ -2838,6 +2871,11 @@ tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
     }
   }

+  auto default_it = op_exec_info->default_dtypes->find(attr);
+  if (default_it != op_exec_info->default_dtypes->end()) {
+    return default_it->second;
+  }
+
   return tensorflow::DT_INVALID;
 }

@@ -3499,6 +3537,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
   }

   op_exec_info.attr_to_inputs_map = GetAttrToInputsMap(*op_def);
+  op_exec_info.default_dtypes = GetAttrToDefaultsMap(*op_def);

   // Mapping of attr name to size - used to calculate the number of values
   // to be expected by the TFE_Execute run.
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index 5299d1ecebe..e29d9b7321a 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import numpy as np
+
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -27,6 +29,7 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -268,6 +271,43 @@ class Tests(test.TestCase):
                                               "transpose_a", False,
                                               "transpose_b", False)

+  def testOpDefDefaultType(self):
+    im = np.random.randint(
+        low=0, high=65535, size=100, dtype=np.uint16).reshape(10, 10, 1)
+
+    context.ensure_initialized()
+
+    fastpath_dtype = test_ops.dtype_with_default_op(im).numpy()
+    slowpath_dtype = test_ops.dtype_with_default_op_eager_fallback(
+        im, None, context.context()).numpy()
+    # Ensure the fastpath and slowpath eager paths work.
+    self.assertEqual(fastpath_dtype, slowpath_dtype)
+
+    with ops.Graph().as_default(), self.cached_session():
+      graph_dtype_symbolic = test_ops.dtype_with_default_op(im)
+
+      graph_dtype = self.evaluate(graph_dtype_symbolic)
+    # Ensure the eager path matches the graph path.
+    self.assertEqual(fastpath_dtype, graph_dtype)
+
+    # Unfortunately, as of now, this doesn't work as expected on def_functions,
+    # since we convert the numpy arrays to tensors pre-tracing (which won't get
+    # overridden by the default type).
+    @def_function.function
+    def func(im):
+      return test_ops.dtype_with_default_op(im)
+
+    function_dtype = func(im).numpy()
+    self.assertNotEqual(fastpath_dtype, function_dtype)
+
+    # Captures are OK, since they don't go through the conversion path.
+    @def_function.function
+    def func_captured():
+      return test_ops.dtype_with_default_op(im)
+
+    function_dtype = func_captured().numpy()
+    self.assertEqual(fastpath_dtype, function_dtype)
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index fc864692b7b..8c57520a4cf 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -690,4 +690,29 @@ REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_CPU),
                         DevicePlacementOp);
 REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_GPU),
                         DevicePlacementOp);
+
+// An op which returns the dtype of the tensor it was passed in. Its type
+// attr T defaults to DT_UINT8.
+REGISTER_OP("DtypeWithDefaultOp")
+    .Input("in: T")
+    .Attr("T: type = DT_UINT8")
+    .Output("dtype: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape);
+
+class DTypeWithDefaultOp : public OpKernel {
+ public:
+  using OpKernel::OpKernel;
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& input = ctx->input(0);
+    Tensor* output;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output("dtype", TensorShape({}), &output));
+    output->scalar<string>()() = tensorflow::DataTypeString(input.dtype());
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp").Device(DEVICE_CPU),
+                        DTypeWithDefaultOp);
 }  // end namespace tensorflow
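
Usage note (not part of the patch): the sketch below is a minimal, hedged illustration of the behavior the change enables, mirroring the new test. It assumes the DtypeWithDefaultOp kernel above has been compiled into the generated tensorflow.python.framework.test_ops wrappers; every name used here comes from the diff itself. Because a plain numpy array carries no TensorFlow dtype, MaybeGetDTypeForAttr can now fall back to the OpDef's default type attr (T = DT_UINT8, via GetAttrToDefaultsMap) when the eager fastpath converts the input, so the fastpath, the eager fallback, and graph mode agree on the inferred dtype.

    import numpy as np

    from tensorflow.python.eager import context
    from tensorflow.python.framework import test_ops

    context.ensure_initialized()

    # No explicit dtype is given: when converting this uint16 numpy array,
    # the fastpath can now consult the op's default type attr as a dtype hint.
    im = np.random.randint(
        low=0, high=65535, size=100, dtype=np.uint16).reshape(10, 10, 1)

    fastpath = test_ops.dtype_with_default_op(im).numpy()
    fallback = test_ops.dtype_with_default_op_eager_fallback(
        im, None, context.context()).numpy()
    assert fastpath == fallback  # both eager paths now resolve the same dtype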