Remove special handling of host-memory/device-memory for int32 arguments in type lists.
The current special case behavior is preserved only for functions and their gradients. Change: 130100547
This commit is contained in:
parent
55b44e625c
commit
84cefad9cc
@ -7,6 +7,9 @@
|
|||||||
default, simply pass the argument `state_is_tuple=False`.
|
default, simply pass the argument `state_is_tuple=False`.
|
||||||
* DeviceFactory's AddDevices and CreateDevices functions now return
|
* DeviceFactory's AddDevices and CreateDevices functions now return
|
||||||
a Status instead of void.
|
a Status instead of void.
|
||||||
|
* Int32 elements of list(type) arguments are no longer placed in host memory by
|
||||||
|
default. If necessary, a list(type) argument to a kernel can be placed in host
|
||||||
|
memory using a HostMemory annotation.
|
||||||
|
|
||||||
# Release 0.10.0
|
# Release 0.10.0
|
||||||
|
|
||||||
|
@ -61,29 +61,6 @@ MemoryType MTypeFromDType(const DataType dtype) {
|
|||||||
return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY;
|
return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the default memory types for type list arguments from the data
|
|
||||||
// types. (The default can be overridden by an explicit HostMemory()
|
|
||||||
// declaration.)
|
|
||||||
Status SetTypeListMTypesFromDTypes(
|
|
||||||
const NameRangeMap& name_ranges,
|
|
||||||
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
|
|
||||||
const DataTypeVector& dtypes, MemoryTypeVector* mtypes) {
|
|
||||||
for (const auto& a : args) {
|
|
||||||
if (!a.type_list_attr().empty()) {
|
|
||||||
auto it = name_ranges.find(a.name());
|
|
||||||
if (it == name_ranges.end()) {
|
|
||||||
return errors::InvalidArgument("Name range for argument ", a.name(),
|
|
||||||
" not found.");
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = it->second.first; i < it->second.second; ++i) {
|
|
||||||
(*mtypes)[i] = MTypeFromDType(dtypes[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
|
Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
|
||||||
@ -107,12 +84,13 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
|
|||||||
inp_mtypes->clear();
|
inp_mtypes->clear();
|
||||||
out_mtypes->clear();
|
out_mtypes->clear();
|
||||||
|
|
||||||
if (!status.ok()) {
|
// For functions (which have no KernelDef) and their gradients, we can only
|
||||||
// When there is no kernel def for this op, we can only best-effort derive
|
// best-effort derive the memory type from the data type. For now, we assume
|
||||||
// the memory type from the data type. For now, we assume int32 is always
|
// int32 is always on host memory and other types are always on device memory.
|
||||||
// on host memory and other types are always on device memory. We should
|
// TODO(zhifengc,phawkins): We should do type inference over function bodies
|
||||||
// do type inference over function body to derive the correct
|
// to derive the correct input/output memory types. We should also split
|
||||||
// input/output memory types.
|
// host-memory and non host-memory arguments into separate type lists.
|
||||||
|
if (!status.ok() || ndef.op() == "SymbolicGradient") {
|
||||||
for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
|
for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
|
||||||
for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
|
for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -127,12 +105,6 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
|
|||||||
inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
|
inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
|
||||||
out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
|
out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
|
||||||
|
|
||||||
// For type list arguments, mark int32 arguments as host memory.
|
|
||||||
TF_RETURN_IF_ERROR(SetTypeListMTypesFromDTypes(inp_names, op_def->input_arg(),
|
|
||||||
inp_dtypes, inp_mtypes));
|
|
||||||
TF_RETURN_IF_ERROR(SetTypeListMTypesFromDTypes(
|
|
||||||
out_names, op_def->output_arg(), out_dtypes, out_mtypes));
|
|
||||||
|
|
||||||
// Fills in host memory types based on the kernel def.
|
// Fills in host memory types based on the kernel def.
|
||||||
const auto& from_proto = kdef->host_memory_arg();
|
const auto& from_proto = kdef->host_memory_arg();
|
||||||
std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
|
std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
|
||||||
|
@ -63,11 +63,11 @@ TEST(MemoryTypesForNode, Simple) {
|
|||||||
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
|
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
|
||||||
&input, &output));
|
&input, &output));
|
||||||
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||||
DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
|
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||||
DEVICE_MEMORY, HOST_MEMORY}),
|
DEVICE_MEMORY, DEVICE_MEMORY}),
|
||||||
input);
|
input);
|
||||||
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||||
HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}),
|
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
|
||||||
output);
|
output);
|
||||||
|
|
||||||
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
|
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
|
||||||
@ -77,7 +77,7 @@ TEST(MemoryTypesForNode, Simple) {
|
|||||||
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
||||||
input);
|
input);
|
||||||
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
||||||
HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}),
|
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
|
||||||
output);
|
output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,12 +109,20 @@ REGISTER_KERNEL_BUILDER(Name("_Retval")
|
|||||||
|
|
||||||
class PassOn : public OpKernel {
|
class PassOn : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
|
||||||
OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
|
OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
|
||||||
errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
|
errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
|
||||||
" vs. ", ctx->num_outputs()));
|
" vs. ", ctx->num_outputs()));
|
||||||
|
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, input_type(i) == output_type(i),
|
||||||
|
errors::Internal("Input and output types for position ", i,
|
||||||
|
" do not match: ", DataTypeString(input_type(i)),
|
||||||
|
" vs. ", DataTypeString(output_type(i))));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||||
ctx->set_output(i, ctx->input(i));
|
ctx->set_output(i, ctx->input(i));
|
||||||
}
|
}
|
||||||
@ -140,12 +148,14 @@ REGISTER_GPU_KERNELS(double);
|
|||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("_ListToArray")
|
REGISTER_KERNEL_BUILDER(Name("_ListToArray")
|
||||||
.Device(DEVICE_GPU)
|
.Device(DEVICE_GPU)
|
||||||
|
.HostMemory("input")
|
||||||
.HostMemory("output")
|
.HostMemory("output")
|
||||||
.TypeConstraint<int32>("T"),
|
.TypeConstraint<int32>("T"),
|
||||||
PassOn);
|
PassOn);
|
||||||
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
|
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
|
||||||
.Device(DEVICE_GPU)
|
.Device(DEVICE_GPU)
|
||||||
.HostMemory("input")
|
.HostMemory("input")
|
||||||
|
.HostMemory("output")
|
||||||
.TypeConstraint<int32>("T"),
|
.TypeConstraint<int32>("T"),
|
||||||
PassOn);
|
PassOn);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user