Remove tensor input shape from function signature.

PiperOrigin-RevId: 265973257
This commit is contained in:
Tong Shen 2019-08-28 12:55:54 -07:00 committed by TensorFlower Gardener
parent ccfe164602
commit bb4cacfe45
6 changed files with 4 additions and 47 deletions

View File

@ -469,10 +469,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
IsMultiDevice(ctx->FindFunctionDef(op->Name())); IsMultiDevice(ctx->FindFunctionDef(op->Name()));
std::vector<Device*> input_dev_ptrs; std::vector<Device*> input_dev_ptrs;
// `input_tensor_shapes` contains (potentially a subset of) non DT_RESOURCE
// arguments, and `input_resource_variable_dtypes_and_shapes` contains shapes
// and underlying types for (potentially a subset) of DT_RESOURCE arguments.
std::unordered_map<int, TensorShape> input_tensor_shapes;
std::unordered_map<int, DtypeAndPartialTensorShape> std::unordered_map<int, DtypeAndPartialTensorShape>
input_resource_variable_dtypes_and_shapes; input_resource_variable_dtypes_and_shapes;
if (is_multi_device_function) { if (is_multi_device_function) {
@ -507,19 +503,9 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
cache_key = cache_key =
FingerprintCat128(cache_key, Fingerprint128(input_device->name())); FingerprintCat128(cache_key, Fingerprint128(input_device->name()));
// If input is normal tensor, get its shape and add it to 'cache_key';
// If input is a ResourceHandle, get its resource handle dtypes and shapes // If input is a ResourceHandle, get its resource handle dtypes and shapes
// and add them to 'cache_key'. // and add them to 'cache_key'.
if (input->dtype != DT_RESOURCE) { if (input->dtype == DT_RESOURCE) {
TensorShape shape;
TF_RETURN_IF_ERROR(input->Shape(&shape));
input_tensor_shapes[i] = shape;
// Add both _Arg index and shape to "cache_key".
cache_key = FingerprintCat128(cache_key, i);
AppendTensorShapeToFingerprint(shape, &cache_key);
} else {
// We only care about data type and shape for resource variable inputs. // We only care about data type and shape for resource variable inputs.
// But we have no way to tell if input is resource variable (other than // But we have no way to tell if input is resource variable (other than
// looking it up in ResourceMgr, which is slow). So we just get // looking it up in ResourceMgr, which is slow). So we just get
@ -596,7 +582,6 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
<< ". Full node_def=" << ndef.DebugString(); << ". Full node_def=" << ndef.DebugString();
kernel.reset(new KernelAndDeviceFunc( kernel.reset(new KernelAndDeviceFunc(
flr, ctx->pflr(), std::move(input_dev_ptrs), flr, ctx->pflr(), std::move(input_dev_ptrs),
std::move(input_tensor_shapes),
std::move(input_resource_variable_dtypes_and_shapes), runner, std::move(input_resource_variable_dtypes_and_shapes), runner,
ctx->GetCollectiveExecutorHandle(), ctx->HostCPU(), op->Name(), ctx->GetCollectiveExecutorHandle(), ctx->HostCPU(), op->Name(),
[ctx](const int64 step_id) { [ctx](const int64 step_id) {

View File

@ -124,7 +124,6 @@ Status KernelAndDeviceFunc::Init(const NodeDef& ndef,
for (const Device* device : input_devices_) { for (const Device* device : input_devices_) {
options.input_devices.push_back(device->name()); options.input_devices.push_back(device->name());
} }
options.input_tensor_shapes = input_tensor_shapes_;
options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_; options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
const auto& it = ndef.attr().find("executor_type"); const auto& it = ndef.attr().find("executor_type");

View File

@ -187,7 +187,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
KernelAndDeviceFunc( KernelAndDeviceFunc(
FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr, FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
std::vector<Device*> input_devices, std::vector<Device*> input_devices,
std::unordered_map<int, TensorShape> input_tensor_shapes,
std::unordered_map<int, DtypeAndPartialTensorShape> std::unordered_map<int, DtypeAndPartialTensorShape>
input_resource_dtypes_and_shapes, input_resource_dtypes_and_shapes,
std::function<void(std::function<void()>)>* runner, std::function<void(std::function<void()>)>* runner,
@ -199,7 +198,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
pflr_(pflr), pflr_(pflr),
handle_(kInvalidHandle), handle_(kInvalidHandle),
input_devices_(std::move(input_devices)), input_devices_(std::move(input_devices)),
input_tensor_shapes_(std::move(input_tensor_shapes)),
input_resource_dtypes_and_shapes_( input_resource_dtypes_and_shapes_(
std::move(input_resource_dtypes_and_shapes)), std::move(input_resource_dtypes_and_shapes)),
name_(name), name_(name),
@ -242,7 +240,6 @@ class KernelAndDeviceFunc final : public KernelAndDevice {
// CPU devices are not null. Resource handles' devices are actual backing // CPU devices are not null. Resource handles' devices are actual backing
// devices. // devices.
std::vector<Device*> input_devices_; std::vector<Device*> input_devices_;
std::unordered_map<int, TensorShape> input_tensor_shapes_;
std::unordered_map<int, DtypeAndPartialTensorShape> std::unordered_map<int, DtypeAndPartialTensorShape>
input_resource_dtypes_and_shapes_; input_resource_dtypes_and_shapes_;

View File

@ -317,7 +317,6 @@ const string* AssignedOrRequestedDeviceName(const Node& node) {
} }
Status SetArgShape( Status SetArgShape(
const std::unordered_map<int, TensorShape>& input_tensor_shapes,
const std::unordered_map<int, DtypeAndPartialTensorShape>& const std::unordered_map<int, DtypeAndPartialTensorShape>&
input_resource_dtypes_and_shapes, input_resource_dtypes_and_shapes,
const std::vector<Node*>& arg_nodes) { const std::vector<Node*>& arg_nodes) {
@ -326,16 +325,7 @@ Status SetArgShape(
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
DataType dtype; DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype)); TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
if (dtype != DT_RESOURCE) { if (dtype == DT_RESOURCE) {
auto shape_iter = input_tensor_shapes.find(index);
if (shape_iter != input_tensor_shapes.end()) {
TensorShapeProto shape_proto;
shape_iter->second.AsProto(&shape_proto);
AttrValue attr_value;
*attr_value.mutable_list()->add_shape() = shape_proto;
n->AddAttr("_output_shapes", attr_value);
}
} else {
auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index); auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) { if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
AttrValue dtype_attr_value; AttrValue dtype_attr_value;
@ -626,9 +616,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
options.graph_collector->CollectRawGraph(def); options.graph_collector->CollectRawGraph(def);
} }
TF_RETURN_IF_ERROR(SetArgShape(options.input_tensor_shapes, TF_RETURN_IF_ERROR(
options.input_resource_dtypes_and_shapes, SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
arg_nodes));
TF_RETURN_IF_ERROR(PinArgsAndRets(options.input_devices, TF_RETURN_IF_ERROR(PinArgsAndRets(options.input_devices,
options.output_devices, device_set_, options.output_devices, device_set_,
arg_nodes, ret_nodes)); arg_nodes, ret_nodes));

View File

@ -921,11 +921,6 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
entries.push_back(strings::StrCat( entries.push_back(strings::StrCat(
"_output_dev", i, "=", absl::CEscape(options.output_devices[i]))); "_output_dev", i, "=", absl::CEscape(options.output_devices[i])));
} }
for (const auto& iter : options.input_tensor_shapes) {
entries.push_back(
strings::StrCat("_input_tensor_shape", iter.first, "=",
absl::CEscape(iter.second.DebugString())));
}
for (const auto& iter : options.input_resource_dtypes_and_shapes) { for (const auto& iter : options.input_resource_dtypes_and_shapes) {
entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=", entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=",
DataTypeString(iter.second.dtype))); DataTypeString(iter.second.dtype)));

View File

@ -563,14 +563,6 @@ class FunctionLibraryRuntime {
// infer correct device. // infer correct device.
std::vector<string> output_devices; std::vector<string> output_devices;
// This interface is EXPERIMENTAL and subject to change.
//
// For multi-device functions, a mapping from _Arg node index to input
// tensor shape.
// REQUIRES: if input_tensor_shapes.count(i) > 0 then i-th argument type
// must not be DT_RESOURCE.
std::unordered_map<int, TensorShape> input_tensor_shapes;
// This interface is EXPERIMENTAL and subject to change. // This interface is EXPERIMENTAL and subject to change.
// //
// For multi-device functions, a mapping from _Arg node index to type and // For multi-device functions, a mapping from _Arg node index to type and