Skip to content

Commit

Permalink
CUDA: Remove scikit-cuda dependency
Browse files Browse the repository at this point in the history
This is done by copying the generated source of the argmin kernel.
No actual python code from skcuda was used, so we no longer need to
import it. This fixes the annoying CUBLAS warning and also makes
start-up faster since the CUBLAS instance never gets created.

Fixes #47
  • Loading branch information
ali1234 committed Jun 12, 2023
1 parent de7f0e1 commit b811428
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
],
extras_require={
'spellcheck': ['pyenchant'],
'CUDA': ['pycuda', 'scikit-cuda'],
'CUDA': ['pycuda'],
'OpenCL': ['pyopencl'],
'viewer': ['PyOpenGL'],
'profiler': ['plop'],
Expand Down
80 changes: 69 additions & 11 deletions teletext/vbi/patterncuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,17 @@
from pycuda.compiler import SourceModule
from pycuda.driver import ctx_flags

from .pattern import Pattern

cuda.init()
cudadevice = cuda.Device(0)
cudacontext = cudadevice.make_context(flags=ctx_flags.SCHED_YIELD)
atexit.register(cudacontext.pop)

import skcuda
from skcuda.misc import _get_minmax_kernel
skcuda.misc._global_cublas_allocator = cuda.mem_alloc

from .pattern import Pattern


class PatternCUDA(Pattern):

mod = SourceModule("""
correlate = SourceModule("""
__global__ void correlate(float *input, float *patterns, float *result, int range_low, int range_high)
{
int x = (threadIdx.x + (blockDim.x*blockIdx.x));
Expand All @@ -47,10 +43,72 @@ class PatternCUDA(Pattern):
result[ridx] += (d*d);
}
}
""")

correlate = mod.get_function("correlate")
argmin = _get_minmax_kernel(np.float32, "min")[1]
""").get_function("correlate")

# argmin from scikit-cuda/blob/master/skcuda/misc.py
argmin = SourceModule("""
Copyright (c) 2009-2019, Lev E. Givon. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are
permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list
of conditions and the following disclaimer in the documentation and/or other materials
provided with the distribution.
Neither the name of Lev E. Givon nor the names of any contributors may be used to
endorse or promote products derived from this software without specific prior
written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <pycuda-complex.hpp>
__global__ void minmax_row_kernel(float* mat, float* target,
unsigned int* idx_target,
unsigned int width,
unsigned int height) {
__shared__ float max_vals[32];
__shared__ unsigned int max_idxs[32];
float cur_max = 3.4028235e+38;
unsigned int cur_idx = 0;
float val = 0;
for (unsigned int i = threadIdx.x; i < width; i += 32) {
val = mat[blockIdx.x * width + i];
if (val < cur_max) {
cur_max = val;
cur_idx = i;
}
}
max_vals[threadIdx.x] = cur_max;
max_idxs[threadIdx.x] = cur_idx;
__syncthreads();
if (threadIdx.x == 0) {
cur_max = 3.4028235e+38;
cur_idx = 0;
for (unsigned int i = 0; i < 32; i++)
if (max_vals[i] < cur_max) {
cur_max = max_vals[i];
cur_idx = max_idxs[i];
}
target[blockIdx.x] = cur_max;
idx_target[blockIdx.x] = cur_idx;
}
}
""").get_function("minmax_row_kernel")

def __init__(self, filename):
Pattern.__init__(self, filename)
Expand Down

0 comments on commit b811428

Please sign in to comment.