Remove tensor input shape from function signature.
PiperOrigin-RevId: 265973257
parent ccfe164602
commit bb4cacfe45
@@ -469,10 +469,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
       IsMultiDevice(ctx->FindFunctionDef(op->Name()));
 
   std::vector<Device*> input_dev_ptrs;
-  // `input_tensor_shapes` contains (potentially a subset of) non DT_RESOURCE
-  // arguments, and `input_resource_variable_dtypes_and_shapes` contains shapes
-  // and underlying types for (potentially a subset) of DT_RESOURCE arguments.
-  std::unordered_map<int, TensorShape> input_tensor_shapes;
   std::unordered_map<int, DtypeAndPartialTensorShape>
       input_resource_variable_dtypes_and_shapes;
   if (is_multi_device_function) {
@@ -507,19 +503,9 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
       cache_key =
           FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
 
-      // If input is normal tensor, get its shape and add it to 'cache_key';
       // If input is a ResourceHandle, get its resource handle dtypes and shapes
       // and add them to 'cache_key'.
-      if (input->dtype != DT_RESOURCE) {
-        TensorShape shape;
-        TF_RETURN_IF_ERROR(input->Shape(&shape));
-
-        input_tensor_shapes[i] = shape;
-
-        // Add both _Arg index and shape to "cache_key".
-        cache_key = FingerprintCat128(cache_key, i);
-        AppendTensorShapeToFingerprint(shape, &cache_key);
-      } else {
+      if (input->dtype == DT_RESOURCE) {
         // We only care about data type and shape for resource variable inputs.
         // But we have no way to tell if input is resource variable (other than
         // looking it up in ResourceMgr, which is slow). So we just get
@@ -596,7 +582,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
             << ". Full node_def=" << ndef.DebugString();
     kernel.reset(new KernelAndDeviceFunc(
         flr, ctx->pflr(), std::move(input_dev_ptrs),
-        std::move(input_tensor_shapes),
         std::move(input_resource_variable_dtypes_and_shapes), runner,
         ctx->GetCollectiveExecutorHandle(), ctx->HostCPU(), op->Name(),
         [ctx](const int64 step_id) {
@@ -124,7 +124,6 @@ Status KernelAndDeviceFunc::Init(const NodeDef& ndef,
   for (const Device* device : input_devices_) {
     options.input_devices.push_back(device->name());
   }
-  options.input_tensor_shapes = input_tensor_shapes_;
   options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
 
   const auto& it = ndef.attr().find("executor_type");
@@ -187,7 +187,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
   KernelAndDeviceFunc(
       FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
       std::vector<Device*> input_devices,
-      std::unordered_map<int, TensorShape> input_tensor_shapes,
       std::unordered_map<int, DtypeAndPartialTensorShape>
           input_resource_dtypes_and_shapes,
       std::function<void(std::function<void()>)>* runner,
@@ -199,7 +198,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
         pflr_(pflr),
         handle_(kInvalidHandle),
         input_devices_(std::move(input_devices)),
-        input_tensor_shapes_(std::move(input_tensor_shapes)),
         input_resource_dtypes_and_shapes_(
             std::move(input_resource_dtypes_and_shapes)),
         name_(name),
@@ -242,7 +240,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
   // CPU devices are not null. Resource handles' devices are actual backing
   // devices.
   std::vector<Device*> input_devices_;
-  std::unordered_map<int, TensorShape> input_tensor_shapes_;
   std::unordered_map<int, DtypeAndPartialTensorShape>
       input_resource_dtypes_and_shapes_;
 
@@ -317,7 +317,6 @@ const string* AssignedOrRequestedDeviceName(const Node& node) {
 }
 
 Status SetArgShape(
-    const std::unordered_map<int, TensorShape>& input_tensor_shapes,
     const std::unordered_map<int, DtypeAndPartialTensorShape>&
         input_resource_dtypes_and_shapes,
     const std::vector<Node*>& arg_nodes) {
@@ -326,16 +325,7 @@ Status SetArgShape(
     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
     DataType dtype;
     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
-    if (dtype != DT_RESOURCE) {
-      auto shape_iter = input_tensor_shapes.find(index);
-      if (shape_iter != input_tensor_shapes.end()) {
-        TensorShapeProto shape_proto;
-        shape_iter->second.AsProto(&shape_proto);
-        AttrValue attr_value;
-        *attr_value.mutable_list()->add_shape() = shape_proto;
-        n->AddAttr("_output_shapes", attr_value);
-      }
-    } else {
+    if (dtype == DT_RESOURCE) {
       auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
       if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
         AttrValue dtype_attr_value;
@@ -626,9 +616,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
     options.graph_collector->CollectRawGraph(def);
   }
 
-  TF_RETURN_IF_ERROR(SetArgShape(options.input_tensor_shapes,
-                                 options.input_resource_dtypes_and_shapes,
-                                 arg_nodes));
+  TF_RETURN_IF_ERROR(
+      SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
   TF_RETURN_IF_ERROR(PinArgsAndRets(options.input_devices,
                                     options.output_devices, device_set_,
                                     arg_nodes, ret_nodes));
@@ -921,11 +921,6 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
      entries.push_back(strings::StrCat(
          "_output_dev", i, "=", absl::CEscape(options.output_devices[i])));
    }
-    for (const auto& iter : options.input_tensor_shapes) {
-      entries.push_back(
-          strings::StrCat("_input_tensor_shape", iter.first, "=",
-                          absl::CEscape(iter.second.DebugString())));
-    }
    for (const auto& iter : options.input_resource_dtypes_and_shapes) {
      entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=",
                                        DataTypeString(iter.second.dtype)));
@@ -563,14 +563,6 @@ class FunctionLibraryRuntime {
     // infer correct device.
     std::vector<string> output_devices;
 
-    // This interface is EXPERIMENTAL and subject to change.
-    //
-    // For multi-device functions, a mapping from _Arg node index to input
-    // tensor shape.
-    // REQUIRES: if input_tensor_shapes.count(i) > 0 then i-th argument type
-    // must not be DT_RESOURCE.
-    std::unordered_map<int, TensorShape> input_tensor_shapes;
-
     // This interface is EXPERIMENTAL and subject to change.
     //
     // For multi-device functions, a mapping from _Arg node index to type and
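
Note (not part of the commit): a minimal caller-side sketch of what the change means for code that fills in FunctionLibraryRuntime::InstantiateOptions. Only the fields `input_devices`, `input_resource_dtypes_and_shapes`, the removed `input_tensor_shapes`, and the `dtype` member of DtypeAndPartialTensorShape are taken from the hunks above; the include paths, the MakeResourceHintedOptions name, the device string, the _Arg index, and the `shape` member name are assumptions for illustration.

#include "tensorflow/core/framework/function.h"      // FunctionLibraryRuntime (assumed path)
#include "tensorflow/core/framework/tensor_shape.h"  // PartialTensorShape (assumed path)

namespace tensorflow {

// Sketch: how a multi-device caller builds instantiate options after this
// change. Plain-tensor shape hints are gone; only DT_RESOURCE arguments
// (resource variables) keep a dtype/shape hint.
FunctionLibraryRuntime::InstantiateOptions MakeResourceHintedOptions() {
  FunctionLibraryRuntime::InstantiateOptions options;
  options.input_devices.push_back("/job:localhost/replica:0/task:0/cpu:0");

  // Removed by this commit -- no longer part of InstantiateOptions:
  //   options.input_tensor_shapes[0] = TensorShape({2, 3});

  DtypeAndPartialTensorShape hint;
  hint.dtype = DT_FLOAT;                    // member name visible in the diff
  hint.shape = PartialTensorShape({2, 3});  // member name assumed
  options.input_resource_dtypes_and_shapes[1] = hint;  // key: illustrative _Arg index
  return options;
}

}  // namespace tensorflow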