Remove tensor input shape from function signature.
PiperOrigin-RevId: 265973257
This commit is contained in:
parent
ccfe164602
commit
bb4cacfe45
@ -469,10 +469,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
IsMultiDevice(ctx->FindFunctionDef(op->Name()));
|
||||
|
||||
std::vector<Device*> input_dev_ptrs;
|
||||
// `input_tensor_shapes` contains (potentially a subset of) non DT_RESOURCE
|
||||
// arguments, and `input_resource_variable_dtypes_and_shapes` contains shapes
|
||||
// and underlying types for (potentially a subset) of DT_RESOURCE arguments.
|
||||
std::unordered_map<int, TensorShape> input_tensor_shapes;
|
||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||
input_resource_variable_dtypes_and_shapes;
|
||||
if (is_multi_device_function) {
|
||||
@ -507,19 +503,9 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
cache_key =
|
||||
FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
|
||||
|
||||
// If input is normal tensor, get its shape and add it to 'cache_key';
|
||||
// If input is a ResourceHandle, get its resource handle dtypes and shapes
|
||||
// and add them to 'cache_key'.
|
||||
if (input->dtype != DT_RESOURCE) {
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(input->Shape(&shape));
|
||||
|
||||
input_tensor_shapes[i] = shape;
|
||||
|
||||
// Add both _Arg index and shape to "cache_key".
|
||||
cache_key = FingerprintCat128(cache_key, i);
|
||||
AppendTensorShapeToFingerprint(shape, &cache_key);
|
||||
} else {
|
||||
if (input->dtype == DT_RESOURCE) {
|
||||
// We only care about data type and shape for resource variable inputs.
|
||||
// But we have no way to tell if input is resource variable (other than
|
||||
// looking it up in ResourceMgr, which is slow). So we just get
|
||||
@ -596,7 +582,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
<< ". Full node_def=" << ndef.DebugString();
|
||||
kernel.reset(new KernelAndDeviceFunc(
|
||||
flr, ctx->pflr(), std::move(input_dev_ptrs),
|
||||
std::move(input_tensor_shapes),
|
||||
std::move(input_resource_variable_dtypes_and_shapes), runner,
|
||||
ctx->GetCollectiveExecutorHandle(), ctx->HostCPU(), op->Name(),
|
||||
[ctx](const int64 step_id) {
|
||||
|
@ -124,7 +124,6 @@ Status KernelAndDeviceFunc::Init(const NodeDef& ndef,
|
||||
for (const Device* device : input_devices_) {
|
||||
options.input_devices.push_back(device->name());
|
||||
}
|
||||
options.input_tensor_shapes = input_tensor_shapes_;
|
||||
options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
|
||||
|
||||
const auto& it = ndef.attr().find("executor_type");
|
||||
|
@ -187,7 +187,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
KernelAndDeviceFunc(
|
||||
FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
|
||||
std::vector<Device*> input_devices,
|
||||
std::unordered_map<int, TensorShape> input_tensor_shapes,
|
||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||
input_resource_dtypes_and_shapes,
|
||||
std::function<void(std::function<void()>)>* runner,
|
||||
@ -199,7 +198,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
pflr_(pflr),
|
||||
handle_(kInvalidHandle),
|
||||
input_devices_(std::move(input_devices)),
|
||||
input_tensor_shapes_(std::move(input_tensor_shapes)),
|
||||
input_resource_dtypes_and_shapes_(
|
||||
std::move(input_resource_dtypes_and_shapes)),
|
||||
name_(name),
|
||||
@ -242,7 +240,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
|
||||
// CPU devices are not null. Resource handles' devices are actual backing
|
||||
// devices.
|
||||
std::vector<Device*> input_devices_;
|
||||
std::unordered_map<int, TensorShape> input_tensor_shapes_;
|
||||
std::unordered_map<int, DtypeAndPartialTensorShape>
|
||||
input_resource_dtypes_and_shapes_;
|
||||
|
||||
|
@ -317,7 +317,6 @@ const string* AssignedOrRequestedDeviceName(const Node& node) {
|
||||
}
|
||||
|
||||
Status SetArgShape(
|
||||
const std::unordered_map<int, TensorShape>& input_tensor_shapes,
|
||||
const std::unordered_map<int, DtypeAndPartialTensorShape>&
|
||||
input_resource_dtypes_and_shapes,
|
||||
const std::vector<Node*>& arg_nodes) {
|
||||
@ -326,16 +325,7 @@ Status SetArgShape(
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
|
||||
if (dtype != DT_RESOURCE) {
|
||||
auto shape_iter = input_tensor_shapes.find(index);
|
||||
if (shape_iter != input_tensor_shapes.end()) {
|
||||
TensorShapeProto shape_proto;
|
||||
shape_iter->second.AsProto(&shape_proto);
|
||||
AttrValue attr_value;
|
||||
*attr_value.mutable_list()->add_shape() = shape_proto;
|
||||
n->AddAttr("_output_shapes", attr_value);
|
||||
}
|
||||
} else {
|
||||
if (dtype == DT_RESOURCE) {
|
||||
auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
|
||||
if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
|
||||
AttrValue dtype_attr_value;
|
||||
@ -626,9 +616,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
options.graph_collector->CollectRawGraph(def);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(SetArgShape(options.input_tensor_shapes,
|
||||
options.input_resource_dtypes_and_shapes,
|
||||
arg_nodes));
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
|
||||
TF_RETURN_IF_ERROR(PinArgsAndRets(options.input_devices,
|
||||
options.output_devices, device_set_,
|
||||
arg_nodes, ret_nodes));
|
||||
|
@ -921,11 +921,6 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
|
||||
entries.push_back(strings::StrCat(
|
||||
"_output_dev", i, "=", absl::CEscape(options.output_devices[i])));
|
||||
}
|
||||
for (const auto& iter : options.input_tensor_shapes) {
|
||||
entries.push_back(
|
||||
strings::StrCat("_input_tensor_shape", iter.first, "=",
|
||||
absl::CEscape(iter.second.DebugString())));
|
||||
}
|
||||
for (const auto& iter : options.input_resource_dtypes_and_shapes) {
|
||||
entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=",
|
||||
DataTypeString(iter.second.dtype)));
|
||||
|
@ -563,14 +563,6 @@ class FunctionLibraryRuntime {
|
||||
// infer correct device.
|
||||
std::vector<string> output_devices;
|
||||
|
||||
// This interface is EXPERIMENTAL and subject to change.
|
||||
//
|
||||
// For multi-device functions, a mapping from _Arg node index to input
|
||||
// tensor shape.
|
||||
// REQUIRES: if input_tensor_shapes.count(i) > 0 then i-th argument type
|
||||
// must not be DT_RESOURCE.
|
||||
std::unordered_map<int, TensorShape> input_tensor_shapes;
|
||||
|
||||
// This interface is EXPERIMENTAL and subject to change.
|
||||
//
|
||||
// For multi-device functions, a mapping from _Arg node index to type and
|
||||
|
Loading…
Reference in New Issue
Block a user