Add DT_BOOL support to GPU variable ops

This is a follow-on to PR #38848 & PR #39172 and resolves the remaining ask
in Issue #35994. The original PR tried to add many variable ops on the
GPU including DT_BOOL. However, this caused testCondModifyBoolPred to
fail and thus the DT_BOOL type was removed. The reason for the test
failure is once DT_BOOL variables are supported on the GPU, we need to
ensure the switch ops are also updated to not have host memory
requirement. Otherwise, a DT_BOOL ref variable is attempted to be
copied to the GPU, which fails since we should not be transferring ref
types.

PiperOrigin-RevId: 316577397
Change-Id: Ic0d96ed4cdf8a0ea4674889aaff3a8ecd50991dd
This commit is contained in:
Gaurav Jain 2020-06-15 17:16:40 -07:00 committed by TensorFlower Gardener
parent 0cc6210daa
commit ac695a31de
4 changed files with 12 additions and 13 deletions

View File

@@ -111,15 +111,17 @@ REGISTER_GPU_SWITCH(uint64);
TF_CALL_variant(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
TF_CALL_bool(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Special GPU kernels for int32, string & resource handles. Requiring all
// inputs and outputs to be in host memory.
// TODO(b/25387198): Also enable int32 in device memory.
#define REGISTER_GPU_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_GPU) \
@@ -149,8 +151,6 @@ TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_REF_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

View File

@@ -252,8 +252,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -73,9 +73,9 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
for attr_key in new_node.attr:
if attr_key == "parallel_iterations":
new_node.attr[attr_key].i = 1
elif new_node.op == "Switch":
# We don't check the inputs to Switch ops as their inputs may be
# Send/Recv nodes.
elif new_node.op == "Switch" or new_node.op == "Identity":
# We don't check the inputs to Switch or Identity ops as their inputs
# may be Send/Recv nodes.
del new_node.input[:]
return output_graph_def

View File

@@ -396,10 +396,10 @@ class CondTest(test_util.TensorFlowTestCase):
fn2=lambda: math_ops.add(y, 23))
self.assertEquals(self.evaluate(z), 24)
@test_util.run_deprecated_v1
@test_util.run_v1_only("Exercises Ref variables")
def testCondModifyBoolPred(self):
# This test in particular used to fail only when running in GPU, hence
# use_gpu=True.
# We want to use the GPU here because we want to ensure that we can update
# a boolean ref variable on the GPU.
with test_util.use_gpu():
bool_var = variable_scope.get_variable(
"bool_var", dtype=dtypes.bool, initializer=True)