Optimizing burst search in python

Premature optimization is the root of all evil.

Donald Knuth

In this post, we’ll see how to optimize a python implementation of the sliding-window burst search algorithm (Fries 1998). We will start by profiling the unoptimized code, then we’ll explore different python optimization techniques. The sections of this post are:

  1. What is the burst search algorithm?
  2. Preparing the data
  3. A pure-python implementation
  4. Finding the bottlenecks
  5. Memoryview in pure python
  6. Vectorizing with numpy
  7. Iterators and loop unwrapping
  8. Beyond pure python: Cython and Numba
  9. Conclusion

Why did I choose python? Well, the arguments are too many to list here, but you can find a nice summary in Jake Vanderplas’ post Why Python is the Last Language You’ll Have To Learn. In a single sentence (quoting John Cook):

I’d rather do mathematics in a general programming language than do general programming in a mathematical language.

What is the burst search algorithm?

To give some context, in the analysis of freely-diffusing single-molecule fluorescence experiments, the burst search algorithm is the central step that allows identifying bursts of photons emitted by single molecules during their diffusion through a small excitation volume. Here we use a simplified burst search that saves only start and stop times of each burst. A complete real-world implementation can be found in FRETBursts, a burst analysis software (relevant code here and here).

Briefly, a burst search algorithm tries to identify bursts in a long stream of events. In single-molecule fluorescence experiments, this stream is represented by photon arrival times (timestamps) with a resolution of a few nanoseconds.

A common algorithm, introduced by the Seidel group (Fries 1998), consists in sliding a window of duration $T$ and identifying a burst when at least $m$ photons lie in this window. The final step of selecting bursts by size (counts, or number of photons) is computationally inexpensive and will be ignored in this post.

Numerically, we need to “slide” the window in discrete steps, and since photon arrival times are stochastic, it makes sense to place the window’s start at each timestamp $t_i$ and check if there are at least $m$ photons between $t_i$ and $t_i + T$.

But how can we “check if there are at least $m$ photons between $t_i$ and $t_i + T$”? We can take a window $T$ and test whether it contains at least $m$ photons, or we can take $m$ consecutive photons ($m$ fixed) and test whether they lie in a time window $\le T$. The latter approach is much easier to implement and more efficient, as it requires looping through the timestamps only once. In this post we’ll follow this latter method.
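To make the test concrete, here is a toy example (made-up timestamps, not the real dataset) checking, for each start index $i$, whether the $m$ consecutive photons starting at $t_i$ span a window $\le T$:

```python
# Toy timestamps in seconds: a cluster at the start and one at the end
times = [0.0, 0.1, 0.2, 5.0, 5.1, 5.2, 5.3]
m, T = 3, 0.5

# For each i, do photons i .. i+m-1 fit in a time window <= T?
hits = [times[i + m - 1] - times[i] <= T
        for i in range(len(times) - m + 1)]
# hits -> [True, False, False, True, True]
```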

For the sake of this post, we assume that each burst is characterized by only a start and stop time. Finally, since the number of bursts is not known in advance, we’ll use a list to collect the bursts found during the search.

Preparing the data

To test the different burst search implementations we can use a single-molecule FRET dataset available on figshare. The file is in the Photon-HDF5 format, so we can load its content with an HDF5 library, such as pytables.

For this post we only need to load the timestamps array, which is converted here to seconds:

In [1]:
import tables
import numpy as np
In [2]:
filename = "data/0023uLRpitc_NTP_20dT_0.5GndCl.hdf5"
In [3]:
with tables.open_file(filename) as h5file:
    timestamps = h5file.root.photon_data.timestamps.read()
    timestamps = timestamps*h5file.root.photon_data.timestamps_specs.timestamps_unit.read()
In [4]:
timestamps
Out[4]:
array([  1.83558750e-03,   2.35056250e-03,   3.67655000e-03, ...,
         5.99998296e+02,   5.99998472e+02,   5.99999442e+02])
In [5]:
timestamps.size
Out[5]:
2683962
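If you don’t have the file at hand, a rough stand-in can be simulated. This is purely an assumption of mine (Poisson arrivals at a constant ~4.5 kHz rate over ~600 s, matching the size and span of the real array); the real data also contains bursts on top of the background:

```python
import numpy as np

# Simulated timestamps: exponential inter-arrival times at ~4.5 kHz,
# accumulated into absolute arrival times (~600 s total, like the real file).
rng = np.random.default_rng(0)
timestamps_sim = np.cumsum(rng.exponential(scale=1 / 4500, size=2_700_000))
```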

A pure-python implementation

The algorithm previously described can be expressed quite naturally with a for-loop:

In [6]:
def burst_search(times, m, T):
    in_burst = False
    bursts = []
    for i in range(len(times) - m - 1):
        if times[i + m - 1] - times[i] <= T:
            if not in_burst:
                in_burst = True
                istart = i
        elif in_burst:
            in_burst = False
            bursts.append((times[istart], times[i+m-1]))
    return bursts

The code is straightforward to read. First note that in_burst is a state-variable telling whether we are inside a burst. With this in mind, the algorithm unfolds as follows:

  1. the $i$ variable loops over the timestamps index
  2. if the $m$ consecutive photons starting at $t_i$ are within a window $\le T$
    1. if a burst is not already started, start the burst and save the start time
  3. Otherwise, if we are inside a burst, stop the burst and save the stop time

Let’s run it. We will use typical values of m=10 (use 10 photons to compute the rate) and $T=100\;{\rm\mu s}$ throughout this post.

In [7]:
bursts_py = burst_search(timestamps, 10, 100e-6)
In [8]:
print('Number of bursts: ', len(bursts_py))
Number of bursts:  18529
In [9]:
%timeit burst_search(timestamps, 10, 100e-6)
1 loops, best of 3: 1.02 s per loop

So, we found 18529 bursts and the execution took around a second. This can be “fast enough” in some cases. However, when dealing with larger files (longer measurements, multi-spot data, etc.) or when interactively exploring the effect of the burst search parameters, we need a faster burst search.

In this post we just want to push the limits of what we can achieve in python, and in the next sections I’ll show various optimization approaches.

Finding the bottlenecks

The first step of any optimization is identifying the bottlenecks. To measure execution time we could use the %prun magic in IPython, which calls the standard python profiler and measures the time spent in each function call. In this case, however, a line-by-line measurement is more insightful, so we will use line_profiler, a package available through PIP or conda. For a more detailed overview of profiling techniques see Cyrille Rossant’s excellent post Profiling and optimizing Python code.

Let’s run our function through line_profiler:

In [10]:
%load_ext line_profiler
In [11]:
%lprun -f burst_search burst_search(timestamps, 10, 100e-6)
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def burst_search(times, m, T):
     2         1            2      2.0      0.0      in_burst = False
     3         1            1      1.0      0.0      bursts = []
     4   2683952      1203389      0.4     25.9      for i in range(len(times) - m - 1):
     5   2683951      2247401      0.8     48.4          if times[i + m - 1] - times[i] <= T:
     6    232747       106153      0.5      2.3              if not in_burst:
     7     18529         7966      0.4      0.2                  in_burst = True
     8     18529         8058      0.4      0.2                  istart = i
     9   2451204      1050190      0.4     22.6          elif in_burst:
    10     18529         8205      0.4      0.2              in_burst = False
    11     18529        16750      0.9      0.4              bursts.append((times[istart], times[i+m-1]))
    12         1            1      1.0      0.0      return bursts

The most unexpected result, for me, was finding that the simple branching (e.g. line 9) accounts for a significant ~20% of the execution time. Beyond that, we see that looping over >2 million elements and computing times[i + m - 1] - times[i] <= T (line 5) at each cycle is where the function spends most of its time. The list-appending, conversely, adds a negligible contribution, since it is performed once per burst (not at each cycle).

Memoryview in pure-python

Numpy arrays are notoriously slow when accessed element by element. Therefore, a big portion of line 5’s execution time may be due to numpy array indexing. We can verify this by expanding line 5 into multiple lines, separating the element access from the comparison and branching:

In [12]:
def burst_search_profile(times, m, T):
    in_burst = False
    bursts = []
    for i in range(len(times) - m - 1):
        t2 = times[i + m - 1]
        t1 = times[i]
        rate_above_threshold = t2 - t1 <= T
        if rate_above_threshold:
            if not in_burst:
                in_burst = True
                istart = i
        elif in_burst:
            in_burst = False
            bursts.append((times[istart], times[i+m-1]))
    return bursts
In [13]:
%lprun -f burst_search_profile burst_search_profile(timestamps, 10, 100e-6)
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def burst_search_profile(times, m, T):
     2         1           20     20.0      0.0      in_burst = False
     3         1            1      1.0      0.0      bursts = []
     4   2683952      1303409      0.5     14.8      for i in range(len(times) - m - 1):
     5   2683951      1806225      0.7     20.5          t2 = times[i + m - 1]
     6   2683951      1526900      0.6     17.3          t1 = times[i]
     7   2683951      1607583      0.6     18.2          rate_above_threshold = t2 - t1 <= T
     8   2683951      1284202      0.5     14.5          if rate_above_threshold:
     9    232747       111491      0.5      1.3              if not in_burst:
    10     18529         8883      0.5      0.1                  in_burst = True
    11     18529         8949      0.5      0.1                  istart = i
    12   2451204      1142937      0.5     12.9          elif in_burst:
    13     18529         8813      0.5      0.1              in_burst = False
    14     18529        20464      1.1      0.2              bursts.append((times[istart], times[i+m-1]))
    15         1            4      4.0      0.0      return bursts

We see that the array element access (lines 5 and 6) accounts for almost 40% of the total execution time.

A workaround for the slow element access is using a memoryview (a built-in python object) to access the data in the numpy array. Since a memoryview and a numpy array can share data, we avoid wasting RAM and waiting for memory copies. To test performance, we just call burst_search() passing a memoryview instead of a numpy array:

In [14]:
%timeit burst_search(memoryview(timestamps), 10, 100e-6)
1 loops, best of 3: 617 ms per loop

This little trick alone cuts the execution time by roughly 40%, rendering the element-access cost negligible. Note that it is not always possible to replace a numpy array with a memoryview, because the latter does not support numpy’s advanced indexing and array operations. In this case we were lucky that we didn’t need any of these features.
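The key point is that the memoryview wraps the same buffer as the array, so no copy is made. A quick sanity check (a sketch of mine, not from the original post):

```python
import numpy as np

a = np.arange(5, dtype=np.float64)
mv = memoryview(a)   # wraps the array's buffer: no data is copied
a[0] = 42.0          # modify through the array...
print(mv[0])         # ...and the view sees the change -> 42.0
```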

Vectorizing with numpy

One step further, we can try removing the loop and replacing it with array operations (vectorization). In this case we cannot completely get rid of the for-loop (at least I cannot find a way to remove it), but we can move some heavy computation out of the loop.

As shown before, the bottleneck is the operation times[i + m - 1] - times[i] <= T. In particular, the array element access at each cycle is quite expensive.

At the cost of higher memory usage, we can perform this subtraction outside the loop:

In [15]:
def burst_search_numpy1(times, m, T):
    in_burst = False
    bursts = []
    delta_t = (times[m-1:] - times[:times.size-m+1])
    for i in range(times.size-m+1):
        if delta_t[i] <= T:
            if not in_burst:
                in_burst = True
                start = times[i]
        elif in_burst:
            in_burst = False
            bursts.append((start, times[i+m-1]))
    return bursts
In [16]:
bursts_numpy1 = burst_search_numpy1(timestamps, 10, 100e-6)
assert bursts_numpy1 == bursts_py
%timeit burst_search_numpy1(timestamps, 10, 100e-6)
1 loops, best of 3: 436 ms per loop

We achieved a ~2x speed improvement, not bad.

This approach can be improved by moving the comparison outside the loop as well:

In [17]:
def burst_search_numpy2(times, m, T):
    in_burst = False
    bursts = []
    above_min_rate = (times[m-1:] - times[:times.size-m+1]) <= T
    for i in range(len(times)-m+1):
        if above_min_rate[i]:
            if not in_burst:
                in_burst = True
                start = times[i]
        elif in_burst:
            in_burst = False
            bursts.append((start, times[i+m-1]))
    return bursts
In [18]:
bursts_numpy2 = burst_search_numpy2(timestamps, 10, 100e-6)
assert bursts_numpy2 == bursts_py
%timeit burst_search_numpy2(timestamps, 10, 100e-6)
1 loops, best of 3: 266 ms per loop

The function is now almost 4 times faster than the initial version, but we can do better.

Iterators and loop unwrapping

The last version moved most computations out of the loop, but we are still accessing a “big” numpy array (the boolean array above_min_rate) element by element. As seen before, this is relatively slow because of all the fancy-indexing machinery that numpy arrays support.

A slightly faster approach is using an iterator to access the numpy array element by element in a loop:

In [19]:
def burst_search_numpy3(times, m, T):
    in_burst = False
    bursts = []
    
    max_index = times.size-m+1
    above_min_rate = (times[m-1:] - times[:max_index]) <= T
    
    for i, above_min_rate_ in enumerate(above_min_rate):
        if above_min_rate_:
            if not in_burst:
                in_burst = True
                start = times[i]
        elif in_burst:
            in_burst = False
            bursts.append((start, times[i+m-1]))
    return bursts
In [20]:
bursts_numpy3 = burst_search_numpy3(timestamps, 10, 100e-6)
assert bursts_numpy3 == bursts_py
%timeit burst_search_numpy3(timestamps, 10, 100e-6)
1 loops, best of 3: 228 ms per loop

We can further optimize line 13, the elif branch that is executed at almost every cycle. To avoid this test we can slightly unwrap the loop as follows:

  • first, we continue as long as the rate is below the threshold (i.e. we are not in a burst)
  • when we stop “continuing”, we are inside a burst, and a small inner loop runs until the burst is over
  • the main loop then resumes from the position reached by the inner loop

The trick is using the same iterator in the two nested loops. It sounds scary, but it turns out to be quite simple:

In [21]:
def burst_search_numpy4(times, m, T):
    bursts = []
    
    max_index = times.size-m+1
    below_min_rate = (times[m-1:] - times[:max_index]) > T
    
    iter_i_belowrate = enumerate((below_min_rate))
    for i, below_min_rate_ in iter_i_belowrate:
        if below_min_rate_: continue           
        
        start = times[i]
        for i, below_min_rate_ in iter_i_belowrate:
            if below_min_rate_: break
        
        bursts.append((start, times[i+m-1]))
    return bursts
In [22]:
bursts_numpy4 = burst_search_numpy4(timestamps, 10, 100e-6)
assert bursts_numpy4 == bursts_py
%timeit burst_search_numpy4(timestamps, 10, 100e-6)
1 loops, best of 3: 206 ms per loop
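If the shared-iterator pattern still looks obscure, here is a stripped-down demo of mine on toy data: the inner loop consumes the same iterator, so the outer loop resumes right after the “run” the inner loop just scanned:

```python
# Toy demo: find runs of odd numbers, recording for each run the
# (first odd value, the even value that ends it).
data = [0, 1, 3, 5, 2, 7, 9, 4]
it = iter(data)
runs = []
for x in it:
    if x % 2 == 0:      # not in a run: keep scanning
        continue
    start = x
    for x in it:        # same iterator: the outer loop resumes after the run
        if x % 2 == 0:  # an even value ends the run
            break
    runs.append((start, x))
# runs -> [(1, 2), (7, 4)]
```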

As a last attempt, we can try using a memoryview instead of a numpy array for below_min_rate. With a memoryview, item access is fast, so we can use a simpler iterator (iterating over the index):

In [23]:
def burst_search_numpy5(times, m, T):
    bursts = []
    below_min_rate = memoryview((times[m-1:] - times[:times.size-m+1]) > T)
    
    iter_i = iter(range(len(times)-m+1))
    for i in iter_i:
        if below_min_rate[i]:
            continue
        
        start = times[i]
        for i in iter_i:
            if below_min_rate[i]:
                break
        
        bursts.append((start, times[i+m-1]))
    return bursts
In [24]:
bursts_numpy5 = burst_search_numpy5(timestamps, 10, 100e-6)
assert bursts_numpy5 == bursts_py
%timeit burst_search_numpy5(timestamps, 10, 100e-6)
10 loops, best of 3: 191 ms per loop

Optimizations 3-5 cut another 25-30% off the execution time, achieving, overall, a 5x faster execution compared to the initial python version. More than the speed, I appreciate that the latest version (burst_search_numpy5) is the easiest to read, mainly because of the elimination of the state-variable in_burst and the use of a simpler iterator. A rare case where optimization and readability don’t conflict.

Beyond pure-python: Cython and Numba

For even faster execution we need to bypass the interpreted nature of the python language. In cases like this, where we run a “hot” loop with item access and branching inside, we are likely to gain a significant speed-up by compiling the python code.

To compile the previous function to machine code we can use cython. Cython extends the python syntax and allows translating python to C statically (the C code is eventually compiled to machine code).
To let cython produce an optimized C version, we need to declare the types of the variables used inside the loop.

The cython version of the previous algorithm is:

In [25]:
%load_ext Cython
In [26]:
%%cython
cimport numpy as np

def burst_search_cy(np.float64_t[:] times, np.int16_t m, np.float64_t T):
    cdef int i
    cdef np.int64_t istart
    cdef np.uint8_t in_burst
    
    in_burst = 0
    bursts = []
    for i in range(times.size - m - 1):
        if times[i + m - 1] - times[i] <= T:
            if not in_burst:
                in_burst = 1
                istart = i
        elif in_burst:
            in_burst = 0
            bursts.append((times[istart], times[i+m-1]))
    return bursts

In the IPython environment (inside Jupyter), we use the %%cython magic command (included in the cython package) to compile the function on the fly. Outside IPython, we can set up distutils (or setuptools) to handle the compilation step (see the cython documentation).

In addition to the import line, I added type definitions for the function arguments. The cython types have the same names as the numpy types, with an additional _t. For the first argument, which is an array, I used the syntax np.float64_t[:] times, which declares a Cython memoryview (like python’s memoryview, but faster). To read more about memoryviews see this post by Jake Vanderplas: Memoryview Benchmarks.

Let’s run the cython function:

In [27]:
bursts_cy = burst_search_cy(timestamps, 10, 100e-6)
In [28]:
assert bursts_cy == bursts_py
In [29]:
%timeit burst_search_cy(timestamps, 10, 100e-6)
100 loops, best of 3: 8.68 ms per loop

Note that the execution time dropped to less than 10 ms, a whopping 100x speed increase compared to the pure-python version. And all we did was add a few variable declarations!

Another optimization tool that emerged in the last couple of years is Numba. Numba is a Just-in-Time (JIT) compiler that analyzes code during execution and translates it to machine code on the fly (recent versions also support static compilation, à la cython). Under the hood, Numba uses the LLVM compiler for the translation to machine code.

In principle, numba can perform more advanced optimizations than cython, and there are reports of 2x speed improvements over cython in special cases.

Numba is even easier to use than cython: we just need to decorate the function with a single line (i.e. @numba.jit). Let’s see Numba at work:

In [30]:
import numba
In [31]:
@numba.jit
def burst_search_numba(times, m, T):
    in_burst = False
    bursts = []
    istart = 0
    for i in range(times.size - m - 1):
        if times[i + m - 1] - times[i] <= T:
            if not in_burst:
                in_burst = True
                istart = i
        elif in_burst:
            in_burst = False
            bursts.append((times[istart], times[i+m-1]))
    return bursts

Unlike cython, numba code is just python code with no special syntax.

Checking the execution time:

In [32]:
bursts_numba = burst_search_numba(timestamps, 10, 100e-6)
In [33]:
assert bursts_numba == bursts_py
In [34]:
%timeit burst_search_numba(timestamps, 10, 100e-6)
100 loops, best of 3: 8.21 ms per loop

we see that the numba version also runs in less than 10 ms (I would declare a tie with Cython).

As a small note, I had to add one extra line (istart = 0) to help Numba infer the type of the istart variable. Knowing that istart is an integer allows Numba to perform more aggressive optimizations.

There are some corner cases that Numba cannot optimize, but these become fewer with each release. For example, version 0.20 was not able to optimize the current example (which then ran at pure-python speed). Conversely, with the latest Numba version (0.22 as of this writing), we reach pure-C speed, probably thanks to the new list optimizations.

Conclusion

Starting from a pure-python implementation of the burst search, we used line_profiler to find the bottlenecks and optimized the code.

The burst search algorithm does not lend itself particularly well to vectorization. Nonetheless, using only pure-python tools (i.e. numpy vectorization, memoryviews and iterators), we achieved a respectable 5x speed improvement. Next, using static compilation (Cython) or JIT compilation (Numba), we reached a 100-fold speed-up.

With all these options for optimizing python code, it is easy to fall into the trap of premature optimization. To avoid it, always optimize as the last step, focusing on the bottlenecks highlighted by a profiler. When in doubt, prefer clarity over speed.

See also

This post is written in Jupyter Notebook. The original notebook can be downloaded here. Content released as CC BY 2.0.
