Rollforward: Add DT_BOOL support to GPU variable ops
Identity on bool had a HostMemory requirement, which caused excessive copies.

PiperOrigin-RevId: 317413034
Change-Id: Ica75743c9d202f5cc5fb8c12a475eda84507f0be
parent 3427843d70
commit f0d0485b0d
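Why the HostMemory requirement hurt: before this change, the only GPU Identity kernel for bool kept its input and output pinned to host memory, so every bool Identity in a GPU graph forced a device-to-host copy on the way in and a host-to-device copy for any downstream device consumer. A minimal sketch of that pre-change registration shape (abbreviated from the REGISTER_GPU_HOST_KERNEL pattern visible in the hunks below, not the verbatim source):

    // Hedged sketch: pinning "input" and "output" to host memory is what
    // forced the extra copies for bool Identity on GPU.
    REGISTER_KERNEL_BUILDER(Name("Identity")
                                .Device(DEVICE_GPU)
                                .HostMemory("input")
                                .HostMemory("output")
                                .TypeConstraint<bool>("T"),
                            IdentityOp);

The change below registers ordinary device-memory kernels for bool and deletes the host-memory ones.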
@@ -107,15 +107,17 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
 TF_CALL_variant(REGISTER_GPU_SWITCH);
+TF_CALL_bool(REGISTER_GPU_SWITCH);
+TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
 
 #undef REGISTER_CPU_SWITCH
 #undef REGISTER_CPU_REF_SWITCH
 #undef REGISTER_GPU_SWITCH
 #undef REGISTER_GPU_REF_SWITCH
 
-// Special GPU kernels for int32 and string.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
+// Special GPU kernels for int32, string & resource handles. Requiring all
+// inputs and outputs to be in host memory.
+// TODO(b/25387198): Also enable int32 in device memory.
 #define REGISTER_GPU_HOST_KERNEL(type)                   \
   REGISTER_KERNEL_BUILDER(Name("Switch")                 \
                               .Device(DEVICE_GPU)        \
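For context, a plausible sketch of the REGISTER_GPU_SWITCH macro instantiated above (its #define body sits before this hunk, so the exact expansion is an assumption): only the scalar predicate is pinned to host memory, which is why TF_CALL_bool(REGISTER_GPU_SWITCH) keeps the bool data tensor on the device.

    // Hedged sketch: GPU Switch registration with device-resident data and a
    // host-resident scalar predicate.
    #define REGISTER_GPU_SWITCH(type)                        \
      REGISTER_KERNEL_BUILDER(Name("Switch")                 \
                                  .Device(DEVICE_GPU)        \
                                  .TypeConstraint<type>("T") \
                                  .HostMemory("pred"),       \
                              SwitchOp)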
@@ -145,8 +147,6 @@ TF_CALL_variant(REGISTER_GPU_SWITCH);
 
 REGISTER_GPU_HOST_KERNEL(int32);
 REGISTER_GPU_HOST_REF_KERNEL(int32);
-REGISTER_GPU_HOST_KERNEL(bool);
-REGISTER_GPU_HOST_REF_KERNEL(bool);
 REGISTER_GPU_HOST_KERNEL(tstring);
 REGISTER_GPU_HOST_REF_KERNEL(tstring);
 REGISTER_GPU_HOST_KERNEL(ResourceHandle);
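The deletions above retire the host-memory Switch kernels for bool, which are superseded by the device-memory registrations added in the first hunk. The REGISTER_GPU_HOST_KERNEL body is truncated in that hunk; a plausible completion (an assumption, following the pattern its name and visible prefix suggest) pins every input and output to host memory:

    // Hedged completion of the truncated macro: all four tensors live in
    // host memory — exactly the copy-inducing registration that bool no
    // longer needs.
    #define REGISTER_GPU_HOST_KERNEL(type)                    \
      REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                                  .Device(DEVICE_GPU)         \
                                  .HostMemory("data")         \
                                  .HostMemory("pred")         \
                                  .HostMemory("output_false") \
                                  .HostMemory("output_true")  \
                                  .TypeConstraint<type>("T"), \
                              SwitchOp)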
@@ -122,6 +122,7 @@ REGISTER_SYCL_HOST_KERNEL(bool);
 
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
 REGISTER_GPU_KERNEL(Variant);
+REGISTER_GPU_KERNEL(bool);
 
 #undef REGISTER_GPU_KERNEL
 
@@ -157,7 +158,6 @@ REGISTER_GPU_KERNEL(Variant);
                         IdentityOp)
 
 REGISTER_GPU_HOST_KERNEL(int32);
-REGISTER_GPU_HOST_KERNEL(bool);
 REGISTER_GPU_HOST_KERNEL(tstring);
 REGISTER_GPU_HOST_KERNEL(ResourceHandle);
 
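Taken together, these two Identity hunks move bool from the host-memory list to the device-memory list. A sketch of the device-memory macro they rely on (assumed from the surrounding context; the #define is outside the hunk):

    // Hedged sketch: plain device-memory Identity registration, so bool data
    // now stays on the GPU across Identity nodes.
    #define REGISTER_GPU_KERNEL(type)                                        \
      REGISTER_KERNEL_BUILDER(                                               \
          Name("Identity").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
          IdentityOp)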
@@ -252,8 +252,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
 
 TF_CALL_int64(REGISTER_GPU_KERNELS);
 TF_CALL_uint32(REGISTER_GPU_KERNELS);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
-TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNELS
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
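The variable-ops hunk above gets its bool support from a one-line swap: TF_CALL_GPU_ALL_TYPES covers everything the old TF_CALL_GPU_NUMBER_TYPES plus TF_CALL_COMPLEX_TYPES pair did, with bool added. A sketch of the relationship this relies on (mirroring the dispatch macros in tensorflow/core/framework/types.h; treat the exact expansion as an assumption):

    // Hedged sketch of the type-dispatch macro: same numeric and complex
    // coverage as before, plus bool — which is what registers the GPU
    // variable-op kernels for DT_BOOL.
    #define TF_CALL_GPU_ALL_TYPES(m) \
      TF_CALL_GPU_NUMBER_TYPES(m) TF_CALL_COMPLEX_TYPES(m) TF_CALL_bool(m)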
@@ -73,9 +73,9 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
       for attr_key in new_node.attr:
         if attr_key == "parallel_iterations":
           new_node.attr[attr_key].i = 1
-      elif new_node.op == "Switch":
-        # We don't check the inputs to Switch ops as their inputs may be
-        # Send/Recv nodes.
+      elif new_node.op == "Switch" or new_node.op == "Identity":
+        # We don't check the inputs to Switch or Identity ops as their inputs
+        # may be Send/Recv nodes.
         del new_node.input[:]
 
     return output_graph_def
@@ -396,10 +396,10 @@ class CondTest(test_util.TensorFlowTestCase):
                          fn2=lambda: math_ops.add(y, 23))
       self.assertEquals(self.evaluate(z), 24)
 
-  @test_util.run_deprecated_v1
+  @test_util.run_v1_only("Exercises Ref variables")
   def testCondModifyBoolPred(self):
-    # This test in particular used to fail only when running in GPU, hence
-    # use_gpu=True.
+    # We want to use the GPU here because we want to ensure that we can update
+    # a boolean ref variable on the GPU.
    with test_util.use_gpu():
      bool_var = variable_scope.get_variable(
          "bool_var", dtype=dtypes.bool, initializer=True)