Make the fastpath correctly respect default types.
Removes some of the discrepancy found in #30113 PiperOrigin-RevId: 282078959 Change-Id: Ia765e90cc2f5365de98106e22e599984294baa74
This commit is contained in:
parent
d8b566f2d6
commit
7807ec92bf
@ -621,6 +621,7 @@ tf_py_test(
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:test_ops",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -144,6 +144,40 @@ AttrToInputsMap* GetAttrToInputsMap(const tensorflow::OpDef& op_def) {
|
||||
return retval;
|
||||
}
|
||||
|
||||
tensorflow::mutex all_attr_to_defaults_maps_lock(
|
||||
tensorflow::LINKER_INITIALIZED);
|
||||
tensorflow::gtl::FlatMap<
|
||||
string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
|
||||
GetAllAttrToDefaultsMaps() {
|
||||
static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
|
||||
string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
|
||||
return all_attr_to_defaults_maps;
|
||||
}
|
||||
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::DataType>* GetAttrToDefaultsMap(
|
||||
const tensorflow::OpDef& op_def) {
|
||||
tensorflow::mutex_lock l(all_attr_to_defaults_maps_lock);
|
||||
auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
|
||||
|
||||
auto* output =
|
||||
tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
|
||||
if (output != nullptr) {
|
||||
return output;
|
||||
}
|
||||
|
||||
auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
|
||||
|
||||
for (const auto& attr : op_def.attr()) {
|
||||
if (attr.type() == "type" && attr.has_default_value()) {
|
||||
new_map->insert({attr.name(), attr.default_value().type()});
|
||||
}
|
||||
}
|
||||
|
||||
(*all_attr_to_defaults_maps)[op_def.name()] = new_map;
|
||||
|
||||
return new_map;
|
||||
}
|
||||
|
||||
struct FastPathOpExecInfo {
|
||||
TFE_Context* ctx;
|
||||
const char* device_name;
|
||||
@ -164,6 +198,7 @@ struct FastPathOpExecInfo {
|
||||
// DTypes can come from another input that has the same attr. So build that
|
||||
// map.
|
||||
const AttrToInputsMap* attr_to_inputs_map;
|
||||
const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
|
||||
tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
|
||||
};
|
||||
|
||||
@ -969,9 +1004,7 @@ const char* TFE_GetPythonString(PyObject* o) {
|
||||
#endif
|
||||
}
|
||||
|
||||
int64_t get_uid() {
|
||||
return _uid++;
|
||||
}
|
||||
int64_t get_uid() { return _uid++; }
|
||||
|
||||
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
|
||||
|
||||
@ -2838,6 +2871,11 @@ tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
|
||||
}
|
||||
}
|
||||
|
||||
auto default_it = op_exec_info->default_dtypes->find(attr);
|
||||
if (default_it != op_exec_info->default_dtypes->end()) {
|
||||
return default_it->second;
|
||||
}
|
||||
|
||||
return tensorflow::DT_INVALID;
|
||||
}
|
||||
|
||||
@ -3499,6 +3537,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
|
||||
}
|
||||
|
||||
op_exec_info.attr_to_inputs_map = GetAttrToInputsMap(*op_def);
|
||||
op_exec_info.default_dtypes = GetAttrToDefaultsMap(*op_def);
|
||||
|
||||
// Mapping of attr name to size - used to calculate the number of values
|
||||
// to be expected by the TFE_Execute run.
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
@ -27,6 +29,7 @@ from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -268,6 +271,43 @@ class Tests(test.TestCase):
|
||||
"transpose_a", False,
|
||||
"transpose_b", False)
|
||||
|
||||
def testOpDefDefaultType(self):
|
||||
im = np.random.randint(
|
||||
low=0, high=65535, size=100, dtype=np.uint16).reshape(10, 10, 1)
|
||||
|
||||
context.ensure_initialized()
|
||||
|
||||
fastpath_dtype = test_ops.dtype_with_default_op(im).numpy()
|
||||
slowpath_dtype = test_ops.dtype_with_default_op_eager_fallback(
|
||||
im, None, context.context()).numpy()
|
||||
# Ensure the fastpath and slowpath eager paths work.
|
||||
self.assertEqual(fastpath_dtype, slowpath_dtype)
|
||||
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
graph_dtype_symbolic = test_ops.dtype_with_default_op(im)
|
||||
|
||||
graph_dtype = self.evaluate(graph_dtype_symbolic)
|
||||
# Ensure the eager path matches the graph path.
|
||||
self.assertEqual(fastpath_dtype, graph_dtype)
|
||||
|
||||
# Unfortunately, as of now, this doesn't work as expected on def_functions,
|
||||
# since we convert the numpy arrays to tensors pre-tracing (which won't get
|
||||
# overriddent by the default type).
|
||||
@def_function.function
|
||||
def func(im):
|
||||
return test_ops.dtype_with_default_op(im)
|
||||
|
||||
function_dtype = func(im).numpy()
|
||||
self.assertNotEqual(fastpath_dtype, function_dtype)
|
||||
|
||||
# Captures are OK, since they don't go through the conversion path.
|
||||
@def_function.function
|
||||
def func_captured():
|
||||
return test_ops.dtype_with_default_op(im)
|
||||
|
||||
function_dtype = func_captured().numpy()
|
||||
self.assertEqual(fastpath_dtype, function_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -690,4 +690,29 @@ REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_CPU),
|
||||
DevicePlacementOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_GPU),
|
||||
DevicePlacementOp);
|
||||
|
||||
// An op which returns the dtype of the tensor it was passed in. It expects
|
||||
// DT_UINT8.
|
||||
REGISTER_OP("DtypeWithDefaultOp")
|
||||
.Input("in: T")
|
||||
.Attr("T: type = DT_UINT8")
|
||||
.Output("dtype: string")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
class DTypeWithDefaultOp : public OpKernel {
|
||||
public:
|
||||
using OpKernel::OpKernel;
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& input = ctx->input(0);
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output("dtype", TensorShape({}), &output));
|
||||
output->scalar<tstring>()() = tensorflow::DataTypeString(input.dtype());
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp").Device(DEVICE_CPU),
|
||||
DTypeWithDefaultOp);
|
||||
} // end namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user