Premature optimization is the root of all evil.
— Donald Knuth
In this post, we’ll see how to optimize a python implementation of the sliding-window burst search algorithm (Fries 1998). We will start by profiling the unoptimized code, then we’ll explore different python optimization techniques. The sections of this post are:
- What is the burst search algorithm?
- Preparing the data
- A pure-python implementation
- Finding the bottlenecks
- Memoryview in pure python
- Vectorizing with numpy
- Iterators and loop unwrapping
- Beyond pure python: Cython and Numba
- Conclusion
Why did I choose python? Well, the arguments are too many to list here, but you can find a nice summary in Jake Vanderplas’ post Why Python is the Last Language You’ll Have To Learn. In a single sentence (quoting John Cook):
I’d rather do mathematics in a general programming language than do general programming in a mathematical language.
What is the burst search algorithm?
To give some context, in the analysis of freely-diffusing single-molecule fluorescence experiments, the burst search algorithm is the central step that allows identifying bursts of photons emitted by single molecules during their diffusion through a small excitation volume. Here we use a simplified burst search that saves only start and stop times of each burst. A complete real-world implementation can be found in FRETBursts, a burst analysis software (relevant code here and here).
Briefly, a burst search algorithm tries to identify bursts in a long stream of events. In single-molecule fluorescence experiments, this stream is represented by photon arrival times (timestamps) with a resolution of a few nanoseconds.
A common algorithm, introduced by the Seidel group (Fries 1998), consists in sliding a window of duration $T$ and identifying a burst when at least $m$ photons lie in this window. The final step of selecting bursts by size (counts, or number of photons) is computationally inexpensive and will be ignored in this post.
Numerically, we need to “slide” the window in discrete steps, and since photon arrival times are stochastic, it makes sense to place the window’s start in correspondence with each timestamp $t_i$ and check if there are at least $m$ photons between $t_i$ and $t_i + T$.
But how can we “check if there are at least $m$ photons between $t_i$ and $t_i + T$”? We can take a window $T$ and test whether it contains at least $m$ photons, or we can take $m$ consecutive photons ($m$ fixed) and test whether they lie in a time window $\le T$. The latter approach is much easier to implement and more efficient, as it requires looping through the timestamps only once. In this post we’ll follow this latter method.
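To make this concrete, here is a minimal sketch of the core test on toy data (the timestamps and parameters below are invented purely for illustration):
import numpy as np

# Toy timestamps in seconds: a dense cluster of 5 photons around t = 1 s
# surrounded by two isolated photons.
times = np.array([0.0, 1.0, 1.001, 1.002, 1.003, 1.004, 2.0])
m, T = 5, 0.01  # at least m photons within a window of duration T

# Test, for each i, whether m consecutive photons starting at t_i
# span a time window <= T.
for i in range(times.size - m + 1):
    if times[i + m - 1] - times[i] <= T:
        print('photons %d-%d are within a burst' % (i, i + m - 1))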
For the sake of this post, we assume that each burst is characterized by only a start and stop time. Finally, since the number of bursts is not known in advance, we’ll use a list to collect the bursts found during the search.
Preparing the data
To test the different burst search implementations we can use a single-molecule FRET dataset available on figshare. The file is in Photon-HDF5 format, so we can load its content with any HDF5 library, such as pytables. For this post we only need to load the timestamps array, which is here converted to seconds:
import tables
import numpy as np

filename = "data/0023uLRpitc_NTP_20dT_0.5GndCl.hdf5"

# Read the timestamps and convert them to seconds
with tables.open_file(filename) as h5file:
    timestamps = h5file.root.photon_data.timestamps.read()
    timestamps = timestamps * h5file.root.photon_data.timestamps_specs.timestamps_unit.read()
timestamps
timestamps.size
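If you don’t have the file at hand, a rough stand-in can be simulated as a Poisson process. This is only a sketch: the mean rate below is an assumption, not a property of the real dataset:
import numpy as np

# Hypothetical substitute for the real data: exponential inter-arrival
# times produce a Poisson photon stream of ~2.7 million timestamps.
np.random.seed(0)
mean_rate = 45e3  # photons per second (assumed)
timestamps = np.cumsum(np.random.exponential(1 / mean_rate, size=2700000))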
A pure-python implementation
The algorithm previously described can be expressed quite naturally with a for-loop:
def burst_search(times, m, T):
    in_burst = False
    bursts = []
    for i in range(len(times) - m - 1):
        if times[i + m - 1] - times[i] <= T:
            if not in_burst:
                in_burst = True
                istart = i
        elif in_burst:
            in_burst = False
            bursts.append((times[istart], times[i+m-1]))
    return bursts
The code is straightforward to read. First, note that in_burst is a state variable telling whether we are inside a burst. With this in mind, the algorithm unfolds as follows:
- the $i$ variable loops over the timestamps index
- if the $m$ consecutive photons starting at $t_i$ are within a window $\le T$
- if a burst has not already started, start the burst and save the start time
- otherwise, if we are inside a burst, stop the burst and save the stop time
Let’s run it. Throughout this post we will use the typical values m=10 (i.e. 10 photons are used to compute the local rate) and $T=100\;{\rm\mu s}$.
bursts_py = burst_search(timestamps, 10, 100e-6)
print('Number of bursts: ', len(bursts_py))
%timeit burst_search(timestamps, 10, 100e-6)
So, we found 18529 bursts and the execution took around a second. This can be “fast enough” in some cases. However, when dealing with larger files (longer measurements, multi-spot data, etc…) or when we need to interactively explore the effect of the burst search parameters, we need a faster burst search.
In this post we just want to push the limits of what we can achieve in python, and in the next sections I’ll show various optimization approaches.
Finding the bottlenecks
The first step of any optimization is identifying the bottlenecks. To measure the execution time we can use the %prun magic in IPython, which calls the standard python profiler and measures the time spent in each function call. In this case, however, a line-by-line measurement is more insightful, so we will use line_profiler, a package available through PIP or conda. For a more detailed overview of different profiling techniques see the excellent post by Cyrille Rossant: Profiling and optimizing Python code.
Let’s run our function through line_profiler:
%load_ext line_profiler
%lprun -f burst_search burst_search(timestamps, 10, 100e-6)
Line # Hits Time Per Hit % Time Line Contents
==============================================================
1 def burst_search(times, m, T):
2 1 2 2.0 0.0 in_burst = False
3 1 1 1.0 0.0 bursts = []
4 2683952 1203389 0.4 25.9 for i in range(len(times) - m - 1):
5 2683951 2247401 0.8 48.4 if times[i + m - 1] - times[i] <= T:
6 232747 106153 0.5 2.3 if not in_burst:
7 18529 7966 0.4 0.2 in_burst = True
8 18529 8058 0.4 0.2 istart = i
9 2451204 1050190 0.4 22.6 elif in_burst:
10 18529 8205 0.4 0.2 in_burst = False
11 18529 16750 0.9 0.4 bursts.append((times[istart], times[i+m-1]))
12 1 1 1.0 0.0 return bursts
The most unexpected result, for me, was finding that the simple branching (e.g. line 9) accounts for a significant 20% of the execution time. Apart from that, we see that looping over >2 million elements and computing times[i + m - 1] - times[i] <= T (line 5) at each cycle is where the function spends most of its time. The list appending, conversely, adds a negligible contribution since it is performed once per burst (not at each cycle).
Memoryview in pure-python
Numpy arrays are notoriously slow when accessed element by element. Therefore, a big portion of line 5’s execution time may be due to numpy array indexing. We can prove this by expanding line 5 into multiple lines, separating the element access from the comparison and branching:
def burst_search_profile(times, m, T):
    in_burst = False
    bursts = []
    for i in range(len(times) - m - 1):
        t2 = times[i + m - 1]
        t1 = times[i]
        rate_above_threshold = t2 - t1 <= T
        if rate_above_threshold:
            if not in_burst:
                in_burst = True
                istart = i
        elif in_burst:
            in_burst = False
            bursts.append((times[istart], times[i+m-1]))
    return bursts
%lprun -f burst_search_profile burst_search_profile(timestamps, 10, 100e-6)
Line # Hits Time Per Hit % Time Line Contents
==============================================================
1 def burst_search_profile(times, m, T):
2 1 20 20.0 0.0 in_burst = False
3 1 1 1.0 0.0 bursts = []
4 2683952 1303409 0.5 14.8 for i in range(len(times) - m - 1):
5 2683951 1806225 0.7 20.5 t2 = times[i + m - 1]
6 2683951 1526900 0.6 17.3 t1 = times[i]
7 2683951 1607583 0.6 18.2 rate_above_threshold = t2 - t1 <= T
8 2683951 1284202 0.5 14.5 if rate_above_threshold:
9 232747 111491 0.5 1.3 if not in_burst:
10 18529 8883 0.5 0.1 in_burst = True
11 18529 8949 0.5 0.1 istart = i
12 2451204 1142937 0.5 12.9 elif in_burst:
13 18529 8813 0.5 0.1 in_burst = False
14 18529 20464 1.1 0.2 bursts.append((times[istart], times[i+m-1]))
15 1 4 4.0 0.0 return bursts
We see that the array element access (lines 5 and 6) accounts for almost 40% of the total execution time.
A workaround for the slow element access is using a memoryview (a built-in python object) to access the data in the numpy array. Since a memoryview and a numpy array can share data, we avoid wasting RAM and waiting for memory copies. To test the performance, we just call burst_search() passing a memoryview instead of a numpy array:
%timeit burst_search(memoryview(timestamps), 10, 100e-6)
This little trick alone provides a more than 35% speed increase, rendering the element-access time negligible. Note that it is not always possible to replace a numpy array with a memoryview, because the latter does not support advanced numpy indexing and array operations. In this case we were lucky that we didn’t need any of these features.
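As a quick sanity check (a sketch, using only built-in attributes), we can verify that the memoryview wraps the array’s buffer rather than copying it:
mv = memoryview(timestamps)

# The view holds a reference to the array and shares its buffer:
# no data is copied.
assert mv.obj is timestamps
assert mv[0] == timestamps[0]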
Vectorizing with numpy
One step further: we can try to remove the loop and replace it with array operations (vectorization). In this case we cannot completely get rid of the for-loop (at least I cannot find a way to remove it), but we can move some heavy computation out of the loop.
As shown before, the bottleneck is the operation times[i + m - 1] - times[i] <= T; in particular, the array element access at each cycle is quite heavy. At the cost of higher memory usage, we can perform this subtraction outside the loop:
def burst_search_numpy1(times, m, T):
    in_burst = False
    bursts = []
    delta_t = (times[m-1:] - times[:times.size-m+1])
    for i in range(times.size-m+1):
        if delta_t[i] <= T:
            if not in_burst:
                in_burst = True
                start = times[i]
        elif in_burst:
            in_burst = False
            bursts.append((start, times[i+m-1]))
    return bursts
bursts_numpy1 = burst_search_numpy1(timestamps, 10, 100e-6)
assert bursts_numpy1 == bursts_py
%timeit burst_search_numpy1(timestamps, 10, 100e-6)
We achieved a ~2x speed improvement, not bad.
This approach can be improved by also moving the comparison outside the loop:
def burst_search_numpy2(times, m, T):
    in_burst = False
    bursts = []
    above_min_rate = (times[m-1:] - times[:times.size-m+1]) <= T
    for i in range(len(times)-m+1):
        if above_min_rate[i]:
            if not in_burst:
                in_burst = True
                start = times[i]
        elif in_burst:
            in_burst = False
            bursts.append((start, times[i+m-1]))
    return bursts
bursts_numpy2 = burst_search_numpy2(timestamps, 10, 100e-6)
assert bursts_numpy2 == bursts_py
%timeit burst_search_numpy2(timestamps, 10, 100e-6)
The execution time is now 3 times faster than the initial one, but we can do better.
Iterators and loop unwrapping
The last version moved most computations out of the loop, but we are still accessing a “big” numpy array (the boolean array above_min_rate) element by element. As seen before, this is relatively slow because of all the fancy indexing numpy arrays support.
A slightly faster approach is using an iterator to access the numpy array element by element in a loop:
def burst_search_numpy3(times, m, T):
    in_burst = False
    bursts = []
    max_index = times.size-m+1
    above_min_rate = (times[m-1:] - times[:max_index]) <= T
    for i, above_min_rate_ in enumerate(above_min_rate):
        if above_min_rate_:
            if not in_burst:
                in_burst = True
                start = times[i]
        elif in_burst:
            in_burst = False
            bursts.append((start, times[i+m-1]))
    return bursts
bursts_numpy3 = burst_search_numpy3(timestamps, 10, 100e-6)
assert bursts_numpy3 == bursts_py
%timeit burst_search_numpy3(timestamps, 10, 100e-6)
We can further optimize the elif branch, which is executed at almost every cycle. To avoid this test we can slightly unwrap the loop as follows:
- first, we continue if we are not above the rate
- when “not continuing”, we are inside a burst, and we run a mini internal loop until the burst is over
- the main loop resumes from the position updated by the inner loop
The trick is using the same iterator in the two nested loops. It sounds scary, but it turns out to be quite simple:
def burst_search_numpy4(times, m, T):
    bursts = []
    max_index = times.size-m+1
    below_min_rate = (times[m-1:] - times[:max_index]) > T
    iter_i_belowrate = enumerate(below_min_rate)
    for i, below_min_rate_ in iter_i_belowrate:
        if below_min_rate_: continue
        start = times[i]
        for i, below_min_rate_ in iter_i_belowrate:
            if below_min_rate_: break
        bursts.append((start, times[i+m-1]))
    return bursts
bursts_numpy4 = burst_search_numpy4(timestamps, 10, 100e-6)
assert bursts_numpy4 == bursts_py
%timeit burst_search_numpy4(timestamps, 10, 100e-6)
As a last attempt, we can try using a memoryview instead of a numpy array for below_min_rate. With a memoryview the item access is fast, so we can use a simpler iterator (iterating over the index):
def burst_search_numpy5(times, m, T):
    bursts = []
    below_min_rate = memoryview((times[m-1:] - times[:times.size-m+1]) > T)
    iter_i = iter(range(len(times)-m+1))
    for i in iter_i:
        if below_min_rate[i]:
            continue
        start = times[i]
        for i in iter_i:
            if below_min_rate[i]:
                break
        bursts.append((start, times[i+m-1]))
    return bursts
bursts_numpy5 = burst_search_numpy5(timestamps, 10, 100e-6)
assert bursts_numpy5 == bursts_py
%timeit burst_search_numpy5(timestamps, 10, 100e-6)
Optimizations 3-5 cut another 25-30% off the execution time, achieving, overall, an execution 5 times faster than the initial pure-python version.
More than for speed, I find the latest version (burst_search_numpy5) the easiest to read, mainly because of the elimination of the state variable in_burst and the use of a simpler iterator. A rare case where optimization and readability don’t conflict.
Beyond pure-python: Cython and Numba
For even faster execution we need to bypass the interpretation step of the python language. In cases like this, in which we run a “hot” loop with item access and branching inside the loop, we will likely gain a significant speed-up if the python code is compiled.
To compile the previous function to machine code we can use cython. Cython extends the python syntax and allows translating python to C statically (which, eventually, is compiled to machine code). To let cython produce an optimized C version, we need to specify the types of the variables used inside the loop. The cython version of the previous algorithm is:
%load_ext Cython
%%cython
cimport numpy as np

def burst_search_cy(np.float64_t[:] times, np.int16_t m, np.float64_t T):
    cdef int i
    cdef np.int64_t istart
    cdef np.uint8_t in_burst
    in_burst = 0
    bursts = []
    for i in range(times.size - m - 1):
        if times[i + m - 1] - times[i] <= T:
            if not in_burst:
                in_burst = 1
                istart = i
        elif in_burst:
            in_burst = 0
            bursts.append((times[istart], times[i+m-1]))
    return bursts
In the IPython environment (inside Jupyter), we use the %%cython magic command (included in the cython package) to compile the function on the fly. Outside IPython, we can set up distutils (or setuptools) to handle the compilation step (see the cython documentation).
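For reference, a minimal setup.py sketch could look like the following; the module and file names are hypothetical and assume the cell above was saved as burst_search_cy.pyx:
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np

# Hypothetical extension: compiles burst_search_cy.pyx to machine code.
ext = Extension('burst_search_cy',
                sources=['burst_search_cy.pyx'],
                include_dirs=[np.get_include()])  # headers for `cimport numpy`

setup(ext_modules=cythonize([ext]))
The extension would then be built with python setup.py build_ext --inplace.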
In addition to the cimport line, I added type definitions for the function arguments. The cython types have the same names as the numpy types with an additional _t. For the first argument, which is an array, I used the syntax np.float64_t[:] times, which defines a cython memoryview (like python’s memoryview, but faster). To read more about memoryviews see this post by Jake Vanderplas: Memoryview Benchmarks.
Let’s run the cython function:
bursts_cy = burst_search_cy(timestamps, 10, 100e-6)
assert bursts_cy == bursts_py
%timeit burst_search_cy(timestamps, 10, 100e-6)
Note that the execution time dropped to less than 10 ms, a whopping 100x speed increase compared to the pure-python version. And all we have done is add a few variable declarations!
Another optimization tool that emerged in the last couple of years is Numba. Numba is a Just-in-Time (JIT) compiler which analyzes code during execution and translates it to machine code on the fly (recent versions also support static compilation, à la cython). Under the hood, Numba uses the LLVM compiler for the translation to machine code.
In principle, numba can perform more advanced optimizations than cython, and there are reports of 2x speed improvements over cython in special cases.
Numba is even easier to use than cython: we just need to add a single line to decorate the function (i.e. @numba.jit). Let’s see Numba at work:
import numba

@numba.jit
def burst_search_numba(times, m, T):
    in_burst = False
    bursts = []
    istart = 0
    for i in range(times.size - m - 1):
        if times[i + m - 1] - times[i] <= T:
            if not in_burst:
                in_burst = True
                istart = i
        elif in_burst:
            in_burst = False
            bursts.append((times[istart], times[i+m-1]))
    return bursts
Unlike cython, numba code is just python code with no special syntax.
Checking the execution time:
bursts_numba = burst_search_numba(timestamps, 10, 100e-6)
assert bursts_numba == bursts_py
%timeit burst_search_numba(timestamps, 10, 100e-6)
we see that the numba version also runs in less than 10 ms (I would declare a tie with Cython).
As a small note, I had to add an extra line (istart = 0) to help Numba infer the type of the istart variable. Knowing that istart is an int64 allows Numba to perform more aggressive optimizations.
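When type inference does fail, Numba can silently fall back to slow object mode. As a sketch of a stricter variant, we could request nopython mode so that such failures raise an error instead (the body is identical to burst_search_numba above):
import numba

# Same body as burst_search_numba, but nopython mode makes Numba raise
# an error if type inference fails, instead of running in object mode.
@numba.jit(nopython=True)
def burst_search_numba_strict(times, m, T):
    in_burst = False
    bursts = []
    istart = 0
    for i in range(times.size - m - 1):
        if times[i + m - 1] - times[i] <= T:
            if not in_burst:
                in_burst = True
                istart = i
        elif in_burst:
            in_burst = False
            bursts.append((times[istart], times[i+m-1]))
    return bursts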
There are some corner cases which Numba cannot optimize, but these become fewer with each release. For example, in version 0.20 Numba was not able to optimize the current example (which then ran at pure-python speed). Conversely, with the latest Numba version (0.22 as of this writing), we reached pure-C speed, probably thanks to the new list optimizations.
Conclusion
Starting from a pure-python implementation of the burst search, we used line_profiler to find the bottlenecks and optimized the code.
The burst search algorithm does not lend itself naturally to vectorization. Nonetheless, using only pure-python tools (i.e. numpy vectorization, memoryview and iterators), we achieved a respectable 5x speed improvement. Next, using static compilation (Cython) or JIT compilation (Numba), we reached a 100x (100-fold!) speed-up.
With all these options for optimization in python, it can be easy to fall into the trap of premature optimization. To avoid it, always perform optimization as the last step, focusing on the bottlenecks highlighted by a profiler. When in doubt, prefer clarity over speed.
See also
FRETBursts: burst analysis software implementing the techniques illustrated in this post in a real-world scenario.
This post is written in Jupyter Notebook. The original notebook can be downloaded here. Content released as CC BY 2.0.