Motion Correction
Authors
Setup
Import Libraries
import numpy as np
import tifffile
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation
from scipy.ndimage import shift
import napari
Download Data
import owncloud
import os

if not os.path.exists('data'):
    print('Creating directory for data')
    os.mkdir('data')
if not os.path.exists('data/data.tif'):
    oc = owncloud.Client.from_public_link('https://uni-bonn.sciebo.de/s/bFDLSfaxRqKlqT7')
    oc.get_file('/', 'data/data.tif');
Output: Creating directory for data
movie = tifffile.imread('data/data.tif')
movie.shape
Output: (3000, 170, 170)
Section 1: Recognizing Motion
Before beginning motion correction, it is important to first recognize the type and extent of motion present in the calcium imaging dataset. This section introduces the process of visually and quantitatively identifying motion artifacts and determining whether rigid correction is appropriate. You will also prepare the dataset by generating a reference frame and ensuring consistent intensity across frames. These preparatory steps are essential for enabling accurate and reliable correction in the following stages.
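One simple way to put a number on motion before correcting it is to correlate every frame with a reference image, such as the mean projection. A displaced frame matches the reference less well, so dips in this per-frame correlation point to candidate motion events. The sketch below assumes a (n_frames, height, width) NumPy array like `movie`. The helper `frame_reference_correlation` is hypothetical: a heuristic check, not part of this tutorial's workflow. It is demonstrated on a small synthetic movie.

```python
import numpy as np

def frame_reference_correlation(movie, reference):
    """Pearson correlation of each frame with a reference image (hypothetical helper)."""
    frames = movie.reshape(movie.shape[0], -1).astype(np.float64)
    ref = reference.ravel().astype(np.float64)
    # Subtract the mean so the dot product below becomes a covariance
    frames = frames - frames.mean(axis=1, keepdims=True)
    ref = ref - ref.mean()
    # Divide by the norms to turn covariance into correlation in [-1, 1]
    return (frames @ ref) / (np.linalg.norm(frames, axis=1) * np.linalg.norm(ref))

# Synthetic movie (assumption): a Gaussian blob with a little noise, 15 frames
rng = np.random.default_rng(0)
yy, xx = np.mgrid[:40, :40]
blob = np.exp(-((yy - 20) ** 2 + (xx - 20) ** 2) / 30.0)
movie_demo = np.stack([blob + 0.01 * rng.standard_normal(blob.shape) for _ in range(15)])
# Simulate a motion event: frames 5-9 are displaced 4 px to the right
movie_demo[5:10] = np.roll(movie_demo[5:10], 4, axis=2)

corr = frame_reference_correlation(movie_demo, movie_demo.mean(axis=0))
print(np.round(corr, 3))  # the displaced frames 5-9 show lower values
```

On the real data the call would be `frame_reference_correlation(movie, movie.mean(axis=0))`. Because the mean projection itself is blurred by any motion, this is a coarse screen and does not replace shift estimation.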
| Code | Description |
|---|---|
| movie[:5] | Select the first five frames (indices 0 to 4) of the movie array. |
| movie[1:10] | Select frames 1 to 9 of the movie array (the stop index, 10, is excluded). |
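Because `movie` is a 3-D array of shape (frames, height, width), slicing the first axis selects frames. The sketch below uses a small stand-in array in place of the real 3000-frame movie and shows the shapes each kind of slice produces.

```python
import numpy as np

# Small stand-in for the real movie: 30 frames of 17 x 17 pixels (assumed shape)
movie_demo = np.arange(30 * 17 * 17, dtype=np.float32).reshape(30, 17, 17)

first_five = movie_demo[:5]       # frames 0, 1, 2, 3, 4
middle = movie_demo[1:10]         # frames 1 through 9; the stop index is excluded
last = movie_demo[-1]             # a single frame: the result is 2-D, not 3-D
every_other = movie_demo[::2]     # step slicing: frames 0, 2, 4, ...

print(first_five.shape)   # (5, 17, 17)
print(middle.shape)       # (9, 17, 17)
print(last.shape)         # (17, 17)
print(every_other.shape)  # (15, 17, 17)
```

Basic slices like these are views into the original array rather than copies. Passing `movie[:200]` to napari therefore does not duplicate the data in memory.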
Exercises
Example: Open and play the movie in napari. Is there motion?
viewer = napari.Viewer()
viewer.add_image(movie)
Exercise: Open the first 200 frames of the movie. Is there motion?
Solution
viewer.add_image(movie[:200])
Exercise: Open the frames between 1000 and 1200 of the movie. Is there motion?
Solution
viewer.add_image(movie[1000:1200])
Example: Create a mean projection of all frames of the movie.
proj = movie.mean(axis=0)
plt.imshow(proj, cmap='gray')
Exercise: Create a mean projection of the frames between 1000 and 1200 of the movie.
Solution
proj = movie[1000:1200].mean(axis=0)
plt.imshow(proj, cmap='gray')
Exercise: Create a mean projection of the first 200 frames of the movie.
Solution
proj = movie[:200].mean(axis=0)
plt.imshow(proj, cmap='gray')
Section 2: Estimating Frame Shifts Relative to a Reference
To correct for motion, it is necessary to know how much each frame in the movie has shifted relative to a stable reference. This section introduces the concept of shift estimation by comparing each frame to the reference image. The outcome is a set of displacement values for each frame that can be used to realign the dataset. By the end of this section, you will understand how motion is quantified and how to structure this information for use in the correction step.
Each shift is a pair of values: [dy, dx], indicating how much the image needs to move to align with the reference frame.
| | dy (Vertical Shift) | dx (Horizontal Shift) |
|---|---|---|
| Meaning | Vertical movement | Horizontal movement |
| Positive Value | Move frame downward | Move frame right |
| Negative Value | Move frame upward | Move frame left |
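To check the sign convention in the table, you can displace an image by a known amount and look at what `phase_cross_correlation` reports. This sketch uses random data rather than the tutorial's movie (an assumption). It displaces the image with `np.roll`, which shifts circularly, so recovering the displacement is exact.

```python
import numpy as np
from skimage.registration import phase_cross_correlation

rng = np.random.default_rng(1)
reference = rng.random((64, 64))

# Displace the content 3 px down (rows) and 2 px left (columns), circularly
moving = np.roll(reference, shift=(3, -2), axis=(0, 1))

# The result is the correction for `moving`: move up 3 (dy = -3) and right 2 (dx = +2)
shift_val, _, _ = phase_cross_correlation(reference, moving)
print(shift_val)

# Applying the estimated [dy, dx] undoes the displacement
restored = np.roll(moving, shift=tuple(int(s) for s in shift_val), axis=(0, 1))
print(np.allclose(restored, reference))
```

The reported shift is the opposite of the displacement that was applied, because it describes how to move the frame back onto the reference.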
| Code | Description |
|---|---|
| phase_cross_correlation(frame_ref, frame) | Estimate the translation between frame_ref and frame by phase cross-correlation. Returns a tuple whose first element is the [dy, dx] shift that registers frame to frame_ref. |
| shift(frame, shift=shift_val) | Translate frame by shift_val (in pixels) with scipy.ndimage.shift, interpolating pixel values where needed. |
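By default `phase_cross_correlation` reports whole-pixel shifts, which is why the outputs below are whole numbers. Its `upsample_factor` parameter refines the estimate to a fraction of a pixel. `shift` interpolates (cubic spline by default), so it can apply such sub-pixel corrections. The sketch assumes a smoothed random image and creates a known sub-pixel displacement with `scipy.ndimage.fourier_shift`. Neither step is part of the tutorial's own data.

```python
import numpy as np
from scipy.ndimage import fourier_shift, gaussian_filter, shift
from skimage.registration import phase_cross_correlation

rng = np.random.default_rng(2)
reference = gaussian_filter(rng.random((96, 96)), sigma=2)

# Displace the image by (+2.4, -1.7) px using the Fourier shift theorem
true_displacement = (2.4, -1.7)
moving = np.fft.ifftn(fourier_shift(np.fft.fftn(reference), true_displacement)).real

# Default precision: whole pixels
coarse, _, _ = phase_cross_correlation(reference, moving)
# upsample_factor=100 refines the estimate to about 1/100 px
fine, _, _ = phase_cross_correlation(reference, moving, upsample_factor=100)
print(coarse, fine)

# Apply the sub-pixel correction; cubic-spline interpolation fills in between pixels
aligned = shift(moving, shift=fine)
```

Sub-pixel estimation costs extra computation per frame. Whether it matters depends on how large the motion is compared with the size of the structures of interest.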
Exercises
Example: Compute the shift of frame 1 relative to frame 0.
shift_val, _, _ = phase_cross_correlation(movie[0], movie[1])
shift_val
Output: array([0., 0.], dtype=float32)
Exercise: Compute the shift of the last frame relative to frame 0.
Solution
shift_val, _, _ = phase_cross_correlation(movie[0], movie[-1])
shift_val
Output: array([1., 0.], dtype=float32)
Exercise: Compute the shift of frame 200 relative to frame 1000.
Solution
shift_val, _, _ = phase_cross_correlation(movie[1000], movie[200])
shift_val
Output: array([1., 1.], dtype=float32)
Example: What is the shift between the first frame and the mean projection of all frames?
sum_frame = movie.mean(axis=0)
shift_val, _, _ = phase_cross_correlation(sum_frame, movie[0])
shift_val
Output: array([0., 0.], dtype=float32)
Exercise: What is the shift between the first frame and the mean projection of frames between 1000 and 1200?
Solution
sum_frame = movie[1000:1200].mean(axis=0)
shift_val, _, _ = phase_cross_correlation(sum_frame, movie[0])
shift_val
Output: array([ 0., -1.], dtype=float32)
Exercise: What is the shift between the last frame and the mean projection of frames between 1000 and 1200?
Solution
sum_frame = movie[1000:1200].mean(axis=0)
shift_val, _, _ = phase_cross_correlation(sum_frame, movie[-1])
shift_val
Output: array([1., 0.], dtype=float32)
Example: Align the second frame with the first frame.
shift_val, _, _ = phase_cross_correlation(movie[0], movie[1])
aligned = shift(movie[1], shift=shift_val)
plt.subplot(1, 2, 1)
plt.imshow(movie[0], cmap='gray')
plt.title("Reference")
plt.subplot(1, 2, 2)
plt.imshow(aligned - movie[0], cmap='gray')
plt.title("Difference After Alignment")
plt.tight_layout()
Exercise: Align frame 200 (the 201st frame) with the first frame.
Solution
shift_val, _, _ = phase_cross_correlation(movie[0], movie[200])
aligned = shift(movie[200], shift=shift_val)
plt.subplot(1, 2, 1)
plt.imshow(movie[0], cmap='gray')
plt.title("Reference")
plt.subplot(1, 2, 2)
plt.imshow(aligned - movie[0], cmap='gray')
plt.title("Difference After Alignment")
plt.tight_layout()
Exercise: Align the last frame with the first frame.
Solution
shift_val, _, _ = phase_cross_correlation(movie[0], movie[-1])
aligned = shift(movie[-1], shift=shift_val)
plt.subplot(1, 2, 1)
plt.imshow(movie[0], cmap='gray')
plt.title("Reference")
plt.subplot(1, 2, 2)
plt.imshow(aligned - movie[0], cmap='gray')
plt.title("Difference After Alignment")
plt.tight_layout()
Example: Align the first frame with the mean projection of all frames.
sum_frame = movie.mean(axis=0)
shift_val, _, _ = phase_cross_correlation(sum_frame, movie[0])
aligned = shift(movie[0], shift=shift_val)
plt.subplot(1, 2, 1)
plt.imshow(sum_frame, cmap='gray')
plt.title("Reference")
plt.subplot(1, 2, 2)
plt.imshow(aligned - sum_frame, cmap='gray')
plt.title("Difference After Alignment")
plt.tight_layout()
Exercise: Align the first frame with the mean projection of frames between 1000 and 1200.
Solution
sum_frame = movie[1000:1200].mean(axis=0)
shift_val, _, _ = phase_cross_correlation(sum_frame, movie[0])
aligned = shift(movie[0], shift=shift_val)
plt.subplot(1, 2, 1)
plt.imshow(sum_frame, cmap='gray')
plt.title("Reference")
plt.subplot(1, 2, 2)
plt.imshow(aligned - sum_frame, cmap='gray')
plt.title("Difference After Alignment")
plt.tight_layout()
Exercise: Align the last frame with the mean projection of all frames between 1000 and 1200.
Solution
sum_frame = movie[1000:1200].mean(axis=0)
shift_val, _, _ = phase_cross_correlation(sum_frame, movie[-1])
aligned = shift(movie[-1], shift=shift_val)
plt.subplot(1, 2, 1)
plt.imshow(sum_frame, cmap='gray')
plt.title("Reference")
plt.subplot(1, 2, 2)
plt.imshow(aligned - sum_frame, cmap='gray')
plt.title("Difference After Alignment")
plt.tight_layout()
Exercise: Estimate shifts for all frames.
Solution
sum_frame = movie[1000:1200].mean(axis=0)
shifts = np.array([phase_cross_correlation(sum_frame, frame)[0] for frame in movie])
Section 3: Motion Trace
The motion trace is a plot of the frame-by-frame shifts along the Y (dy) and X (dx) directions.
A good trace should have:
- small fluctuations.
- no sudden jumps between neighboring frames.
Signs of outliers:
- a sudden spike in the shift (Y or X direction).
- a flatline followed by a jump, which could suggest a movement event or an error in frame reading.
- a single frame with extreme values compared to its neighbors, which may need to be excluded or handled separately.
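The jump criteria above can be checked automatically by measuring how far the shift moves from one frame to the next. This is a minimal sketch. The helper `flag_jumps` and its 2-pixel threshold are hypothetical choices, not values from this tutorial. The trace is synthetic and contains a single-frame spike.

```python
import numpy as np

def flag_jumps(shifts, max_step=2.0):
    """Return indices of frames whose [dy, dx] shift differs from the previous
    frame's by more than max_step pixels (hypothetical helper and threshold)."""
    # Euclidean length of the change in shift between consecutive frames
    steps = np.linalg.norm(np.diff(shifts, axis=0), axis=1)
    # diff entry i compares frames i and i + 1, so report the later frame
    return np.flatnonzero(steps > max_step) + 1

# Synthetic trace: small jitter plus a one-frame spike at frame 6
shifts_demo = np.array([[0, 0], [0, 1], [1, 1], [0, 0], [0, 1],
                        [1, 0], [8, -5], [1, 0], [0, 1], [0, 0]], dtype=float)
print(flag_jumps(shifts_demo))  # [6 7]
```

An isolated outlier is flagged twice: once for the jump into it (frame 6) and once for the jump back out (frame 7). A sustained movement event is flagged only once. On the tutorial data the call would be `flag_jumps(shifts)`.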
| Code | Description |
|---|---|
| plt.plot(shifts) | Plot the shifts array with Matplotlib. For an (n_frames, 2) array, each column is drawn as a separate line, so the plot shows dy and dx over time. |
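Because `plt.plot(shifts)` draws one line per column, it helps to label them. The sketch below uses a tiny made-up trace in place of the real `shifts`. The `matplotlib.use("Agg")` line is only there so the snippet runs outside a notebook; leave it out in Jupyter.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit inside a notebook
import matplotlib.pyplot as plt
import numpy as np

# Made-up (n_frames, 2) trace standing in for the real shifts array
shifts_demo = np.array([[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]], dtype=float)

fig, ax = plt.subplots()
lines = ax.plot(shifts_demo)  # returns one Line2D per column: dy first, then dx
ax.legend(lines, ["dy (vertical)", "dx (horizontal)"])
ax.set_xlabel("Frame")
ax.set_ylabel("Shift (pixels)")
```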
Exercises
Example: Plot the motion trace by estimating shifts with respect to the first frame.
sum_frame = movie[0]
shifts = np.array([phase_cross_correlation(sum_frame, frame)[0] for frame in movie])
plt.plot(shifts);
Exercise: Plot the motion trace by estimating shifts with respect to the mean projection of all frames.
Solution
sum_frame = movie.mean(axis=0)
shifts = np.array([phase_cross_correlation(sum_frame, frame)[0] for frame in movie])
plt.plot(shifts);
Exercise: Plot the motion trace by estimating shifts with respect to the mean projection of frames between 1000 and 1200.
Solution
sum_frame = movie[1000:1200].mean(axis=0)
shifts = np.array([phase_cross_correlation(sum_frame, frame)[0] for frame in movie])
plt.plot(shifts);
Shifting a frame leaves pixels along the border with no source data. scipy.ndimage.shift fills them according to its mode argument, which affects the intensity near the image edges:
| Mode | What It Does | Expected Plot Behavior |
|---|---|---|
| constant | Pads shifted-in regions with a fixed value (cval, default 0) | Sudden spikes or drops in intensity at frames with large shifts |
| nearest | Extends the edge of the image by repeating the closest pixel | Smooth, closely tracks the original trace |
| reflect | Mirrors pixel values from inside the image at the edges | Smooth, but with slightly more variation |
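The differences between the modes are easiest to see on one short row of pixels shifted 2 pixels to the right. This sketch sets `order=0` (no interpolation) so an integer shift keeps the values exact. That is an assumption for clarity: the tutorial's calls use the default cubic interpolation.

```python
import numpy as np
from scipy.ndimage import shift

row = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# Shift right by 2 px; the two left-most output pixels have no source data
results = {mode: shift(row, 2, order=0, mode=mode)
           for mode in ("constant", "nearest", "reflect")}
for mode, out in results.items():
    print(mode, out)
# constant [0. 0. 1. 2. 3.]  -> filled with cval (0)
# nearest  [1. 1. 1. 2. 3.]  -> edge pixel repeated
# reflect  [2. 1. 1. 2. 3.]  -> values mirrored about the edge
```

The zeros introduced by `constant` are what can pull down the mean intensity of the border rows in frames with large shifts.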
Example: Apply the constant border mode and compare the mean intensity of the top 5 rows in each frame of the original and motion-corrected movies.
sum_frame = movie[1000:1200].mean(axis=0)
shifts = np.array([phase_cross_correlation(sum_frame, frame)[0] for frame in movie])
aligned = np.array([shift(f, shift=sh, mode='constant') for f, sh in zip(movie, shifts)])
border_orig = np.mean(movie[:, :5, :], axis=(1, 2))
border_aligned = np.mean(aligned[:, :5, :], axis=(1, 2))
plt.plot(border_orig, label='Original')
plt.plot(border_aligned, label='Aligned')
plt.legend()
Exercise: Apply the nearest border mode and compare the mean intensity of the top 5 rows in each frame of the original and motion-corrected movies.
Solution
aligned = np.array([shift(f, shift=sh, mode='nearest') for f, sh in zip(movie, shifts)])
top_orig = np.mean(movie[:, :5, :], axis=(1, 2))
top_corr = np.mean(aligned[:, :5, :], axis=(1, 2))
plt.plot(top_orig, label='Border Original')
plt.plot(top_corr, label='Border Aligned')
plt.legend()
Exercise: Apply the reflect border mode and compare the mean intensity of the top 5 rows in each frame of the original and motion-corrected movies.
Solution
aligned = np.array([shift(f, shift=sh, mode='reflect') for f, sh in zip(movie, shifts)])
top_orig = np.mean(movie[:, :5, :], axis=(1, 2))
top_corr = np.mean(aligned[:, :5, :], axis=(1, 2))
plt.plot(top_orig, label='Border Original')
plt.plot(top_corr, label='Border Aligned')
plt.legend()
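The steps in Sections 2 and 3 can be bundled into one function that estimates a [dy, dx] shift for every frame against a reference and then applies it. This is a sketch of rigid correction under stated assumptions. The name `rigid_motion_correct` and the default `mode="nearest"` are hypothetical choices, the latter picked because the table above expects that mode to track the original trace most closely. The demonstration uses a synthetic movie with known integer offsets.

```python
import numpy as np
from scipy.ndimage import shift
from skimage.registration import phase_cross_correlation

def rigid_motion_correct(movie, reference, mode="nearest"):
    """Estimate a [dy, dx] shift per frame against `reference` and apply it.

    Hypothetical helper; returns (aligned movie, shifts of shape (n_frames, 2)).
    """
    shifts = np.array([phase_cross_correlation(reference, frame)[0] for frame in movie])
    aligned = np.array([shift(frame, shift=sh, mode=mode)
                        for frame, sh in zip(movie, shifts)])
    return aligned, shifts

# Synthetic movie (assumption): one random image displaced by known offsets
rng = np.random.default_rng(3)
base = rng.random((48, 48))
offsets = [(0, 0), (2, -1), (-3, 2), (1, 1)]
movie_demo = np.stack([np.roll(base, off, axis=(0, 1)) for off in offsets])

aligned_demo, shifts_demo = rigid_motion_correct(movie_demo, reference=base)
print(shifts_demo)  # each row is the negative of the offset that was applied
```

On the tutorial data a call might look like `aligned, shifts = rigid_motion_correct(movie, movie[1000:1200].mean(axis=0))`. A stable stretch of frames is a reasonable reference choice, but this is a judgement call rather than a fixed rule.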