Add DT_BOOL support to GPU variable ops
This is a follow-on to PR #38848 & PR #39172 and resolves remaining ask in Issue #35994. The original PR tried to add many variable ops on the GPU including DT_BOOL. However, this caused testCondModifyBoolPred to fail and thus the DT_BOOL type was removed. The reason for the test failure is once DT_BOOL variables are supported on the GPU, we need to ensure the switch ops are also updated to not have host memory requirement. Otherwise, a DT_BOOL ref variable is attempted to be copied to the GPU which fails since we should not be transfering ref types. PiperOrigin-RevId: 316577397 Change-Id: Ic0d96ed4cdf8a0ea4674889aaff3a8ecd50991dd
This commit is contained in:
parent
0cc6210daa
commit
ac695a31de
|
@ -111,15 +111,17 @@ REGISTER_GPU_SWITCH(uint64);
|
||||||
TF_CALL_variant(REGISTER_GPU_SWITCH);
|
TF_CALL_variant(REGISTER_GPU_SWITCH);
|
||||||
TF_CALL_uint32(REGISTER_GPU_SWITCH);
|
TF_CALL_uint32(REGISTER_GPU_SWITCH);
|
||||||
TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
|
TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
|
||||||
|
TF_CALL_bool(REGISTER_GPU_SWITCH);
|
||||||
|
TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
|
||||||
|
|
||||||
#undef REGISTER_CPU_SWITCH
|
#undef REGISTER_CPU_SWITCH
|
||||||
#undef REGISTER_CPU_REF_SWITCH
|
#undef REGISTER_CPU_REF_SWITCH
|
||||||
#undef REGISTER_GPU_SWITCH
|
#undef REGISTER_GPU_SWITCH
|
||||||
#undef REGISTER_GPU_REF_SWITCH
|
#undef REGISTER_GPU_REF_SWITCH
|
||||||
|
|
||||||
// Special GPU kernels for int32 and string.
|
// Special GPU kernels for int32, string & resource handles. Requiring all
|
||||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
// inputs and outputs to be in host memory.
|
||||||
// registration requires all int32 inputs and outputs to be in host memory.
|
// TODO(b/25387198): Also enable int32 in device memory.
|
||||||
#define REGISTER_GPU_HOST_KERNEL(type) \
|
#define REGISTER_GPU_HOST_KERNEL(type) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("Switch") \
|
REGISTER_KERNEL_BUILDER(Name("Switch") \
|
||||||
.Device(DEVICE_GPU) \
|
.Device(DEVICE_GPU) \
|
||||||
|
@ -149,8 +151,6 @@ TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
|
||||||
|
|
||||||
REGISTER_GPU_HOST_KERNEL(int32);
|
REGISTER_GPU_HOST_KERNEL(int32);
|
||||||
REGISTER_GPU_HOST_REF_KERNEL(int32);
|
REGISTER_GPU_HOST_REF_KERNEL(int32);
|
||||||
REGISTER_GPU_HOST_KERNEL(bool);
|
|
||||||
REGISTER_GPU_HOST_REF_KERNEL(bool);
|
|
||||||
REGISTER_GPU_HOST_KERNEL(tstring);
|
REGISTER_GPU_HOST_KERNEL(tstring);
|
||||||
REGISTER_GPU_HOST_REF_KERNEL(tstring);
|
REGISTER_GPU_HOST_REF_KERNEL(tstring);
|
||||||
REGISTER_GPU_HOST_KERNEL(ResourceHandle);
|
REGISTER_GPU_HOST_KERNEL(ResourceHandle);
|
||||||
|
|
|
@ -252,8 +252,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
|
||||||
|
|
||||||
TF_CALL_int64(REGISTER_GPU_KERNELS);
|
TF_CALL_int64(REGISTER_GPU_KERNELS);
|
||||||
TF_CALL_uint32(REGISTER_GPU_KERNELS);
|
TF_CALL_uint32(REGISTER_GPU_KERNELS);
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
|
||||||
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
|
|
||||||
#undef REGISTER_GPU_KERNELS
|
#undef REGISTER_GPU_KERNELS
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
|
|
@ -73,9 +73,9 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||||
for attr_key in new_node.attr:
|
for attr_key in new_node.attr:
|
||||||
if attr_key == "parallel_iterations":
|
if attr_key == "parallel_iterations":
|
||||||
new_node.attr[attr_key].i = 1
|
new_node.attr[attr_key].i = 1
|
||||||
elif new_node.op == "Switch":
|
elif new_node.op == "Switch" or new_node.op == "Identity":
|
||||||
# We don't check the inputs to Switch ops as their inputs may be
|
# We don't check the inputs to Switch or Identity ops as their inputs
|
||||||
# Send/Recv nodes.
|
# may be Send/Recv nodes.
|
||||||
del new_node.input[:]
|
del new_node.input[:]
|
||||||
|
|
||||||
return output_graph_def
|
return output_graph_def
|
||||||
|
|
|
@ -396,10 +396,10 @@ class CondTest(test_util.TensorFlowTestCase):
|
||||||
fn2=lambda: math_ops.add(y, 23))
|
fn2=lambda: math_ops.add(y, 23))
|
||||||
self.assertEquals(self.evaluate(z), 24)
|
self.assertEquals(self.evaluate(z), 24)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_v1_only("Exercises Ref variables")
|
||||||
def testCondModifyBoolPred(self):
|
def testCondModifyBoolPred(self):
|
||||||
# This test in particular used to fail only when running in GPU, hence
|
# We want to use the GPU here because we want to ensure that we can update
|
||||||
# use_gpu=True.
|
# a boolean ref variable on the GPU.
|
||||||
with test_util.use_gpu():
|
with test_util.use_gpu():
|
||||||
bool_var = variable_scope.get_variable(
|
bool_var = variable_scope.get_variable(
|
||||||
"bool_var", dtype=dtypes.bool, initializer=True)
|
"bool_var", dtype=dtypes.bool, initializer=True)
|
||||||
|
|
Loading…
Reference in New Issue