Make the fastpath correctly respect default types.

Removes some of the discrepancy found in #30113

PiperOrigin-RevId: 282078959
Change-Id: Ia765e90cc2f5365de98106e22e599984294baa74
This commit is contained in:
Akshay Modi 2019-11-22 18:11:56 -08:00 committed by TensorFlower Gardener
parent d8b566f2d6
commit 7807ec92bf
4 changed files with 108 additions and 3 deletions

View File

@ -621,6 +621,7 @@ tf_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:random_ops",
"//tensorflow/python:test_ops",
],
)

View File

@ -144,6 +144,40 @@ AttrToInputsMap* GetAttrToInputsMap(const tensorflow::OpDef& op_def) {
return retval;
}
tensorflow::mutex all_attr_to_defaults_maps_lock(
tensorflow::LINKER_INITIALIZED);
tensorflow::gtl::FlatMap<
string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
GetAllAttrToDefaultsMaps() {
static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
return all_attr_to_defaults_maps;
}
tensorflow::gtl::FlatMap<string, tensorflow::DataType>* GetAttrToDefaultsMap(
const tensorflow::OpDef& op_def) {
tensorflow::mutex_lock l(all_attr_to_defaults_maps_lock);
auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
auto* output =
tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
if (output != nullptr) {
return output;
}
auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
for (const auto& attr : op_def.attr()) {
if (attr.type() == "type" && attr.has_default_value()) {
new_map->insert({attr.name(), attr.default_value().type()});
}
}
(*all_attr_to_defaults_maps)[op_def.name()] = new_map;
return new_map;
}
struct FastPathOpExecInfo {
TFE_Context* ctx;
const char* device_name;
@ -164,6 +198,7 @@ struct FastPathOpExecInfo {
// DTypes can come from another input that has the same attr. So build that
// map.
const AttrToInputsMap* attr_to_inputs_map;
const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
};
@ -969,9 +1004,7 @@ const char* TFE_GetPythonString(PyObject* o) {
#endif
}
int64_t get_uid() {
return _uid++;
}
int64_t get_uid() { return _uid++; }
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
@ -2838,6 +2871,11 @@ tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
}
}
auto default_it = op_exec_info->default_dtypes->find(attr);
if (default_it != op_exec_info->default_dtypes->end()) {
return default_it->second;
}
return tensorflow::DT_INVALID;
}
@ -3499,6 +3537,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
}
op_exec_info.attr_to_inputs_map = GetAttrToInputsMap(*op_def);
op_exec_info.default_dtypes = GetAttrToDefaultsMap(*op_def);
// Mapping of attr name to size - used to calculate the number of values
// to be expected by the TFE_Execute run.

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@ -27,6 +29,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -268,6 +271,43 @@ class Tests(test.TestCase):
"transpose_a", False,
"transpose_b", False)
def testOpDefDefaultType(self):
im = np.random.randint(
low=0, high=65535, size=100, dtype=np.uint16).reshape(10, 10, 1)
context.ensure_initialized()
fastpath_dtype = test_ops.dtype_with_default_op(im).numpy()
slowpath_dtype = test_ops.dtype_with_default_op_eager_fallback(
im, None, context.context()).numpy()
# Ensure the fastpath and slowpath eager paths work.
self.assertEqual(fastpath_dtype, slowpath_dtype)
with ops.Graph().as_default(), self.cached_session():
graph_dtype_symbolic = test_ops.dtype_with_default_op(im)
graph_dtype = self.evaluate(graph_dtype_symbolic)
# Ensure the eager path matches the graph path.
self.assertEqual(fastpath_dtype, graph_dtype)
# Unfortunately, as of now, this doesn't work as expected on def_functions,
# since we convert the numpy arrays to tensors pre-tracing (which won't get
# overriddent by the default type).
@def_function.function
def func(im):
return test_ops.dtype_with_default_op(im)
function_dtype = func(im).numpy()
self.assertNotEqual(fastpath_dtype, function_dtype)
# Captures are OK, since they don't go through the conversion path.
@def_function.function
def func_captured():
return test_ops.dtype_with_default_op(im)
function_dtype = func_captured().numpy()
self.assertEqual(fastpath_dtype, function_dtype)
if __name__ == "__main__":
test.main()

View File

@ -690,4 +690,29 @@ REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_CPU),
DevicePlacementOp);
REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_GPU),
DevicePlacementOp);
// An op which returns the dtype of the tensor it was passed in. It expects
// DT_UINT8.
REGISTER_OP("DtypeWithDefaultOp")
.Input("in: T")
.Attr("T: type = DT_UINT8")
.Output("dtype: string")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
class DTypeWithDefaultOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
Tensor* output;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("dtype", TensorShape({}), &output));
output->scalar<tstring>()() = tensorflow::DataTypeString(input.dtype());
}
};
REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp").Device(DEVICE_CPU),
DTypeWithDefaultOp);
} // end namespace tensorflow