3x3 median filter, or branchless loopless sort
So, while thinking of something I could add to socles to make it worth uploading, I thought of implementing a median filter.
Median filters are a bit of a pain because one has to sort all of the pixel values in the window (for a greyscale image at least) before getting the result, and sorts generally require loops and other conditional code, which makes them a bit inefficient when executed on a GPU.
So somehow I'm back to looking into sorts again... the never-ending quest continues.
Fortunately, this is another case where a Batcher's sort or another hand-rolled sorting network comes to the rescue. And since each output pixel is processed independently, there are no inter-thread dependencies, and all the processing can be unrolled and performed in registers (if you have enough of them).
So I implemented the 9-element sorting network from Knuth's The Art of Computer Programming, Volume 3, section 5.3.4, which sorts 9 elements in 25 compare-and-swap steps. Since 3 of those 25 comparisons don't affect the median value, I dropped them, leaving 22 inline compare-and-swaps to find the median. Fortunately this is a small enough problem to fit entirely into registers too.
This also works pretty well in C: an (admittedly inlined) version manages 100M median calculations on a 9-integer array (the same array copied over each time) in 1.26s on my workstation, while using glibc's qsort() for the same task takes 17s. (I didn't verify that the inlined version was actually transferring data to and from memory as it should in the micro-benchmark, but the figure sounds about right, all things considered.)
So, given unsorted values s0-s8 and an operator cas(a, b) (compare and swap) which leaves a <= b, the median can be calculated using the following 22 steps:
cas(s1, s2); cas(s4, s5); cas(s7, s8);
cas(s0, s1); cas(s3, s4); cas(s6, s7);
cas(s1, s2); cas(s4, s5); cas(s7, s8);
cas(s3, s6); cas(s4, s7); cas(s5, s8);
cas(s0, s3); cas(s1, s4); cas(s2, s5);
cas(s3, s6); cas(s4, s7);
cas(s1, s3); cas(s2, s6); cas(s2, s3);
cas(s4, s6); cas(s3, s4);
s4 now contains the median value. cas can be implemented using 1 compare and 2 selects.
But perhaps the nicest thing is that the code is directly vectorisable.
Some timings using a 4-channel 32-bit 640x480 input image and producing a 4-channel 32-bit output image: for 1 element it takes 84µs, for 3 elements 175µs and for 4 elements 238µs. This is on an NVIDIA GTX 480.
The current implementation: median.cl.