Quantum Python: Animating the Schrodinger Equation

Update: a reader contributed some improvements to the Python code presented below. Please see the pySchrodinger GitHub repository for updated code.

In a previous post I explored the new animation capabilities of the latest matplotlib release. It got me wondering whether it would be possible to simulate more complicated physical systems in real time in python. Quantum Mechanics was the first thing that came to mind. It turns out that by mixing a bit of Physics knowledge with a bit of computing knowledge, it's quite straightforward to simulate and animate a simple quantum mechanical system with python.

The Schrodinger Equation

The dynamics of a one-dimensional quantum system are governed by the time-dependent Schrodinger equation:

$$ i\hbar\frac{\partial \psi}{\partial t} = \frac{-\hbar^2}{2m} \frac{\partial^2 \psi}{\partial x^2} + V \psi $$

The wave function $\psi$ is a function of both position $x$ and time $t$, and is the fundamental description of the realm of the very small. Imagine we are following the motion of a single particle in one dimension. The squared magnitude $|\psi(x, t)|^2$ of this wave function gives the probability density for measuring the particle at position $x$ at time $t$. Quantum mechanics tells us that (contrary to our familiar classical reasoning) this probability is not a limitation of our knowledge of the system, but a reflection of an unavoidable uncertainty in the particle's position.
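To make the probabilistic reading concrete, here is a minimal sketch using a Gaussian wave packet of the same form as the `gauss_x` helper in the script later in this post (the width, center, and momentum values are arbitrary illustrative choices). The integral of $|\psi|^2$ over all $x$ should be 1, and integrating over a sub-interval gives the probability of finding the particle there:

```python
import numpy as np

def gauss_x(x, a, x0, k0):
    """Gaussian wave packet of width a, centered at x0, with mean wave number k0
    (same form as the gauss_x helper in the script below)."""
    return ((a * np.sqrt(np.pi)) ** (-0.5)
            * np.exp(-0.5 * ((x - x0) / a) ** 2 + 1j * x * k0))

x = np.linspace(-50, 50, 4001)
dx = x[1] - x[0]
psi = gauss_x(x, a=2.0, x0=-10.0, k0=1.5)

density = np.abs(psi) ** 2             # probability density |psi(x)|^2
total = density.sum() * dx             # Riemann-sum estimate of the integral over all x
p_right = density[x > 0].sum() * dx    # probability of finding the particle at x > 0
print(total, p_right)
```

Here `total` comes out at 1 to within rounding, and `p_right` is tiny because the packet is centered far to the left of $x = 0$.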

This equation is still a bit opaque, and to visualize its solutions we'll need to solve it numerically. We'll approach this using the split-step Fourier method.

The Split-step Fourier Method

A standard way to numerically solve certain differential equations is through the use of the Fourier transform. We'll use a Fourier convention of the following form:

$$ \widetilde{\psi}(k, t) = \frac{1}{\sqrt{2\pi}} \int_{-\infty}^{\infty} \psi(x, t) e^{-ikx} dx $$

Under this convention, the associated inverse Fourier Transform is given by:

$$ \psi(x, t) = \frac{1}{\sqrt{2\pi}} \int_{-\infty}^{\infty} \widetilde{\psi}(k, t) e^{ikx} dk $$

Substituting this into the Schrodinger equation and simplifying gives the Fourier-space form of the Schrodinger equation:

$$ i\hbar\frac{\partial \widetilde{\psi}}{\partial t} = \frac{\hbar^2 k^2}{2m} \widetilde{\psi} + V(i\frac{\partial}{\partial k})\widetilde{\psi} $$

The two versions of the Schrodinger equation contain an interesting symmetry: in each, the time derivative is the sum of a simple multiplication of the wave function (by $V(x)$ in $x$-space, or by $\hbar^2 k^2 / 2m$ in $k$-space) and a more complicated term involving derivatives with respect to $x$ or $k$. The key observation is that while the equation is difficult to evaluate fully within either form, each basis offers a straightforward calculation of one of the two contributions. This suggests an efficient strategy to numerically solve the Schrodinger Equation: alternate between the two bases, handling each term in the basis where it is simple.
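As a quick check of this observation, the sketch below uses `numpy.fft` (an assumption on my part; the script later in the post uses `scipy.fftpack`, which follows the same sign and normalization convention). A second derivative in $x$ becomes multiplication by $(ik)^2 = -k^2$ in $k$-space, which we compare against the analytic second derivative of $e^{-x^2/2}$:

```python
import numpy as np

N = 256
L = 40.0
dx = L / N
x = dx * (np.arange(N) - N // 2)          # grid on [-20, 20)
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)   # wave numbers in numpy's FFT ordering

psi = np.exp(-0.5 * x ** 2)
# In k-space, d^2/dx^2 is just multiplication by (ik)^2 = -k^2
d2psi_spectral = np.fft.ifft(-(k ** 2) * np.fft.fft(psi))
d2psi_exact = (x ** 2 - 1) * np.exp(-0.5 * x ** 2)

max_err = np.max(np.abs(d2psi_spectral - d2psi_exact))
print(max_err)
```

The derivative term that is awkward in $x$-space reduces to a pointwise multiply in $k$-space, which is exactly what the split-step method exploits.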

First we solve the straightforward part of the $x$-space Schrodinger equation:

$$ i\hbar\frac{\partial \psi}{\partial t} = V(x) \psi $$

For a small time step $\Delta t$, this has a solution of the form

$$ \psi(x, t + \Delta t) = \psi(x, t) e^{-i V(x) \Delta t / \hbar} $$
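The $x$-space update is nothing more than a pointwise phase factor. A minimal numpy sketch (the harmonic potential and the test state are arbitrary illustrative choices) confirms that it leaves $|\psi|$ unchanged at every grid point and that two half-steps compose into one full step:

```python
import numpy as np

hbar, dt = 1.0, 0.01
x = np.linspace(-10, 10, 501)
V = 0.5 * x ** 2                                  # illustrative potential (harmonic)
psi = np.exp(-0.5 * (x - 1.0) ** 2 + 2j * x)      # arbitrary unnormalized test state

full = np.exp(-1j * V * dt / hbar)                # exp(-i V dt / hbar) at every grid point
half = np.exp(-0.5j * V * dt / hbar)              # the corresponding half-step factor
psi_new = psi * full

# The update only rotates the local phase, so |psi| is unchanged point by point
mod_change = np.max(np.abs(np.abs(psi_new) - np.abs(psi)))
# Two half-steps compose into one full step
compose_err = np.max(np.abs(psi * half * half - psi_new))
print(mod_change, compose_err)
```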

Second, we solve the straightforward part of the $k$-space Schrodinger equation:

$$ i\hbar\frac{\partial \widetilde{\psi}}{\partial t} = \frac{\hbar^2 k^2}{2 m} \widetilde{\psi} $$

For a small time step $\Delta t$, this has a solution of the form

$$ \widetilde{\psi}(k, t + \Delta t) = \widetilde{\psi}(k, t) e^{-i \hbar k^2 \Delta t / (2m)} $$
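For a free particle ($V = 0$) the $k$-space step alone is exact for any time interval, so it can be checked against the textbook result for a spreading Gaussian: the center moves at $\hbar k_0 / m$ and the standard deviation of $|\psi|^2$ grows as $\sigma(t) = \sigma_0 \sqrt{1 + (\hbar t / m a^2)^2}$ with $\sigma_0 = a/\sqrt{2}$. This sketch uses `numpy.fft` with `np.fft.fftfreq` wave numbers rather than a shifted $k$ grid; the packet parameters are illustrative assumptions:

```python
import numpy as np

hbar, m = 1.0, 1.0
N, dx = 2048, 0.1
x = dx * (np.arange(N) - N // 2)           # grid on [-102.4, 102.3]
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)    # wave numbers in FFT ordering

a, x0, k0 = 2.0, -20.0, 1.0
psi0 = (a * np.sqrt(np.pi)) ** -0.5 * np.exp(-0.5 * ((x - x0) / a) ** 2 + 1j * k0 * x)

def free_evolve(psi, t):
    """Exact V = 0 evolution: multiply each Fourier mode by exp(-i hbar k^2 t / 2m)."""
    return np.fft.ifft(np.fft.fft(psi) * np.exp(-1j * hbar * k ** 2 * t / (2 * m)))

def mean_and_width(psi):
    """Mean and standard deviation of the normalized density |psi|^2."""
    rho = np.abs(psi) ** 2
    rho /= rho.sum()
    mean = np.sum(x * rho)
    return mean, np.sqrt(np.sum((x - mean) ** 2 * rho))

t = 10.0
mean, width = mean_and_width(free_evolve(psi0, t))
sigma0 = a / np.sqrt(2)                    # initial standard deviation of |psi|^2
expected_mean = x0 + hbar * k0 * t / m
expected_width = sigma0 * np.sqrt(1 + (hbar * t / (m * a ** 2)) ** 2)
print(mean, expected_mean, width, expected_width)
```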

Numerical Considerations

Solving this system numerically will require repeated computations of the Fourier transform of $\psi(x, t)$ and the inverse Fourier transform of $\widetilde{\psi}(k, t)$. The best-known algorithm for computation of numerical Fourier transforms is the Fast Fourier Transform (FFT), which is available in scipy and efficiently computes the following form of the discrete Fourier transform:

$$ \widetilde{F_m} = \sum_{n=0}^{N-1} F_n e^{-2\pi i n m / N} $$

and its inverse

$$ F_n = \frac{1}{N} \sum_{m=0}^{N-1} \widetilde{F_m} e^{2\pi i n m / N} $$
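It's worth confirming that the library really uses this convention (sign of the exponent, and the $1/N$ on the inverse). The sketch below evaluates both sums directly as matrix products and compares them with `numpy.fft.fft` and `numpy.fft.ifft`; `scipy.fftpack`, used in the script below, follows the same convention. The random input is arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
N = 16
F = rng.normal(size=N) + 1j * rng.normal(size=N)    # arbitrary complex samples F_n

nm = np.outer(np.arange(N), np.arange(N))           # matrix of products n*m
F_tilde = np.exp(-2j * np.pi * nm / N) @ F          # sum_n F_n exp(-2 pi i n m / N)
F_back = np.exp(2j * np.pi * nm / N) @ F_tilde / N  # (1/N) sum_m F~_m exp(+2 pi i n m / N)

print(np.allclose(F_tilde, np.fft.fft(F)),
      np.allclose(F_back, np.fft.ifft(F_tilde)),
      np.allclose(F_back, F))
```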

We need to know how these relate to the continuous Fourier transforms defined and used above. Let's take the example of the forward transform. Assume that the infinite integral is well-approximated by the finite integral from $a$ to $b$, so that we can write

$$ \widetilde{\psi}(k, t) = \frac{1}{\sqrt{2\pi}} \int_a^b \psi(x, t) e^{-ikx} dx $$

This approximation ends up being equivalent to assuming that the potential $V(x) \to \infty$ at $x \le a$ and $x \ge b$. We'll now approximate this integral as a Riemann sum of $N$ terms, and define $\Delta x = (b - a) / N$ and $x_n = a + n\Delta x$ (so that $x_0 = a$):

$$ \widetilde{\psi}(k, t) \simeq \frac{1}{\sqrt{2\pi}} \sum_{n=0}^{N-1} \psi(x_n, t) e^{-ikx_n} \Delta x $$

This is starting to look like the discrete Fourier transform! To bring it even closer, let's define $k_m = k_0 + m\Delta k$, with $\Delta k = 2\pi / (N\Delta x)$. Then our approximation becomes

$$ \widetilde{\psi}(k_m, t) \simeq \frac{1}{\sqrt{2\pi}} \sum_{n=0}^{N-1} \psi(x_n, t) e^{-ik_m x_n} \Delta x $$

(Note that just as we have limited the range of $x$ above, we have here limited the range of $k$ as well. This means that high-frequency components of the signal will be lost in our approximation. The Nyquist sampling theorem tells us that this is an unavoidable consequence of choosing discrete steps in space, and it can be shown that the spacing we chose above exactly satisfies the Nyquist limit if we choose $k_0 = - \pi / \Delta x$).
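A small sketch of the $k$ grid with the Nyquist choice $k_0 = -\pi / \Delta x$ (assuming $N$ is even; the values of $N$ and $\Delta x$ are illustrative). The grid $k_m = k_0 + m \Delta k$ contains exactly the wave numbers that `numpy.fft.fftfreq` assigns to the FFT outputs, just in sorted order instead of FFT order; the phase factor $e^{-i k_0 x_n}$ in the derivation is what lets FFT index $m$ line up with $k_m$:

```python
import numpy as np

N, dx = 8, 0.5
dk = 2 * np.pi / (N * dx)
k0 = -np.pi / dx                            # Nyquist choice for the minimum wave number
k_grid = k0 + dk * np.arange(N)             # k_m = k_0 + m dk, for m = 0..N-1

k_fft = 2 * np.pi * np.fft.fftfreq(N, d=dx) # numpy's wave numbers, in FFT order
print(k_grid)
print(np.sort(k_fft))
```

Both arrays run from $-\pi/\Delta x$ up to $\pi/\Delta x - \Delta k$.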

Plugging our expressions for $x_n$ and $k_m$ into the Fourier approximation and rearranging, we find the following:

$$ \left[\widetilde{\psi}(k_m, t) e^{i m x_0 \Delta k}\right] \simeq \sum_{n=0}^{N-1} \left[ \frac{\Delta x}{\sqrt{2\pi}} \psi(x_n, t) e^{-ik_0 x_n} \right] e^{-2\pi i m n / N} $$

Similar arguments applied to the inverse Fourier transform, together with the relation $\Delta x \, \Delta k = 2\pi / N$, yield:

$$ \left[\frac{\Delta x}{\sqrt{2 \pi}} \psi(x_n, t) e^{-i k_0 x_n}\right] \simeq \frac{1}{N} \sum_{m=0}^{N-1} \left[\widetilde{\psi}(k_m, t) e^{i m x_0 \Delta k} \right] e^{2\pi i m n / N} $$

Comparing these to the discrete Fourier transforms above, we find that the continuous Fourier pair

$$ \psi(x, t) \Longleftrightarrow \widetilde{\psi}(k, t) $$

corresponds to the discrete Fourier pair

$$ \frac{\Delta x}{\sqrt{2 \pi}} \psi(x_n, t) e^{-i k_0 x_n} \Longleftrightarrow \widetilde{\psi}(k_m, t) e^{i m x_0 \Delta k} $$

subject to the approximations mentioned above. This allows a fast numerical evaluation of the Schrodinger equation.
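The discrete pair can be tested directly: sample a Gaussian packet, apply the $x$-side factors, take an FFT, strip the $e^{i m x_0 \Delta k}$ phase, and compare with the analytic continuous transform (the same `gauss_k` formula used in the script below). The grid and packet parameters here are illustrative assumptions, and the grid's minimum wave number is called `k_min` to avoid clashing with the packet's mean wave number:

```python
import numpy as np

def gauss_x(x, a, x0, k0):
    return ((a * np.sqrt(np.pi)) ** (-0.5)
            * np.exp(-0.5 * ((x - x0) / a) ** 2 + 1j * x * k0))

def gauss_k(k, a, x0, k0):
    # analytic continuous Fourier transform of gauss_x
    return ((a / np.sqrt(np.pi)) ** 0.5
            * np.exp(-0.5 * (a * (k - k0)) ** 2 - 1j * (k - k0) * x0))

N, dx = 1024, 0.1
x = dx * (np.arange(N) - 0.5 * N)          # x_n = x_0 + n dx, with x_0 = -51.2
dk = 2 * np.pi / (N * dx)
k_min = -np.pi / dx                        # the k_0 of the text
k = k_min + dk * np.arange(N)              # k_m = k_0 + m dk

psi_x = gauss_x(x, a=2.0, x0=-5.0, k0=3.0)

# Left side of the discrete pair: (dx / sqrt(2 pi)) psi(x_n) e^{-i k_0 x_n}
f_n = dx / np.sqrt(2 * np.pi) * psi_x * np.exp(-1j * k_min * x)
# The FFT gives psi~(k_m) e^{i m x_0 dk}; remove that phase to recover psi~(k_m)
psi_k = np.fft.fft(f_n) * np.exp(-1j * np.arange(N) * x[0] * dk)

max_err = np.max(np.abs(psi_k - gauss_k(k, a=2.0, x0=-5.0, k0=3.0)))
print(max_err)
```

These are the same two phase corrections that the `_set_psi_x` and `_get_psi_k` methods in the script apply.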

Putting It All Together: the Algorithm

We now put this all together using the following algorithm:

  1. Choose $a$, $b$, $N$, and $k_0$ as above, sufficient to represent the initial state of your wave function $\psi(x)$. (Warning: this is perhaps the hardest part of the entire solution. If limits in $x$ or $k$ are chosen which do not suit your problem, then the approximations used above can destroy the accuracy of the calculation!) Once these are chosen, $\Delta x = (b - a) / N$ and $\Delta k = 2\pi / (b - a)$. Define $x_n = a + n \Delta x$ and $k_m = k_0 + m \Delta k$.

  2. Discretize the wave functions on this grid. Let $\psi_n(t) = \psi(x_n, t)$, $V_n = V(x_n)$, and $\widetilde{\psi}_m(t) = \widetilde{\psi}(k_m, t)$.

  3. To advance the system by a time step $\Delta t$, perform the following:

      1. Compute a half-step in $x$: $\psi_n \longleftarrow \psi_n \exp[-i (\Delta t / 2) (V_n / \hbar)]$

      2. Calculate $\widetilde{\psi}_m$ from $\psi_n$ using the FFT, applied to the modified quantities of the discrete Fourier pair above.

      3. Compute a full step in $k$: $\widetilde{\psi}_m \longleftarrow \widetilde{\psi}_m \exp[-i \hbar k_m^2 \Delta t / (2 m)]$

      4. Calculate $\psi_n$ from $\widetilde{\psi}_m$ using the inverse FFT.

      5. Compute a second half-step in $x$: $\psi_n \longleftarrow \psi_n \exp[-i (\Delta t / 2)(V_n / \hbar)]$

  4. Repeat step 3 until the desired time is reached.
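The steps above can be sketched as a short numpy function. This is a simplified illustration, not the script's implementation: it uses numpy's FFT ordering of $k$ (from `np.fft.fftfreq`) instead of the shifted $k_0$ grid and the phase factors of the discrete pair, which amounts to choosing $k_0 = -\pi / \Delta x$. The barrier and packet parameters are illustrative assumptions, with $\hbar = m = 1$:

```python
import numpy as np

def split_step(psi_x, V_x, dx, dt, n_steps, hbar=1.0, m=1.0):
    """Advance psi_x by n_steps steps of size dt using
    half-step in x / full step in k / half-step in x."""
    N = psi_x.size
    k = 2 * np.pi * np.fft.fftfreq(N, d=dx)
    x_half = np.exp(-0.5j * V_x * dt / hbar)          # exp(-i (dt/2) V / hbar)
    k_full = np.exp(-0.5j * hbar * k ** 2 * dt / m)   # exp(-i hbar k^2 dt / 2m)
    psi = np.asarray(psi_x, dtype=complex)
    for _ in range(n_steps):
        psi = psi * x_half
        psi = np.fft.ifft(np.fft.fft(psi) * k_full)
        psi = psi * x_half
    return psi

# Usage: a normalized Gaussian packet heading toward an illustrative square barrier
N, dx = 1024, 0.1
x = dx * (np.arange(N) - N // 2)
V = np.where((x > 0) & (x < 1.0), 2.0, 0.0)
psi0 = np.pi ** -0.25 * np.exp(-0.5 * (x + 15) ** 2 + 1.5j * x)

psi = split_step(psi0, V, dx, dt=0.01, n_steps=1000)
norm_before = np.sum(np.abs(psi0) ** 2) * dx
norm_after = np.sum(np.abs(psi) ** 2) * dx
print(norm_before, norm_after)

# With V = 0 the x-factors are 1, so the scheme reduces to exact free evolution
psi_free = split_step(psi0, np.zeros_like(x), dx, dt=0.01, n_steps=1000)
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)
psi_exact = np.fft.ifft(np.fft.fft(psi0) * np.exp(-0.5j * k ** 2 * 10.0))
```

Every operation in the loop is a pure phase multiplication or an FFT round trip, so the total probability is conserved to rounding error.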

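Note that we have split the $x$-space time step into two half-steps placed symmetrically around the $k$-space step. This symmetric splitting makes the method second-order accurate in $\Delta t$, rather than first-order as it would be if the full $x$-space step were performed all at once. In practice, the second half-step of one iteration and the first half-step of the next combine into a single full step, which is what the code below does when several steps are taken in a row. Those familiar with numerical integration algorithms may recognize this as closely related to the well-known leap-frog integration technique.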
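To see the benefit of the symmetric split, here is a hedged numerical experiment. It evolves a displaced Gaussian in a harmonic potential $V = x^2/2$ (an illustrative choice, with $\hbar = m = 1$) using the symmetric scheme (Strang splitting) and the unsymmetric "full $V$ step, then full $k$ step" scheme (Lie splitting), and compares both against a reference run with a much smaller time step. Theory predicts that halving $\Delta t$ cuts the Strang error by roughly a factor of 4 and the Lie error by roughly a factor of 2:

```python
import numpy as np

N, dx = 512, 0.1
x = dx * (np.arange(N) - N // 2)
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)
V = 0.5 * x ** 2                                         # illustrative harmonic potential
psi0 = np.pi ** -0.25 * np.exp(-0.5 * (x - 2.0) ** 2)   # displaced Gaussian
hbar = m = 1.0

def evolve(psi, t, dt, symmetric):
    """Evolve to time t with step dt, using Strang (symmetric) or Lie splitting."""
    n = int(round(t / dt))
    kin = np.exp(-0.5j * hbar * k ** 2 * dt / m)
    if symmetric:                                        # half V, full T, half V
        half = np.exp(-0.5j * V * dt / hbar)
        for _ in range(n):
            psi = half * np.fft.ifft(kin * np.fft.fft(half * psi))
    else:                                                # full V, then full T
        full = np.exp(-1j * V * dt / hbar)
        for _ in range(n):
            psi = np.fft.ifft(kin * np.fft.fft(full * psi))
    return psi

def l2(u):
    return np.sqrt(np.sum(np.abs(u) ** 2) * dx)

ref = evolve(psi0, 1.0, 2e-4, True)
err = {(dt, sym): l2(evolve(psi0, 1.0, dt, sym) - ref)
       for dt in (0.02, 0.01) for sym in (True, False)}
ratio_strang = err[(0.02, True)] / err[(0.01, True)]
ratio_lie = err[(0.02, False)] / err[(0.01, False)]
print(ratio_strang, ratio_lie)
```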

To test this out, I've written a python script which sets up a particle in a box with a potential barrier. The barrier is high enough that a classical particle would be unable to penetrate it. A quantum particle, however, can "tunnel" through, leading to a non-zero probability of finding the particle on the other side of the partition. This quantum tunneling effect lies at the core of technologies as diverse as electron microscopy, semiconductor diodes, and perhaps even the future of low-powered transistors.
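For a rough sense of how much of the packet should get through, the textbook plane-wave transmission probability through a rectangular barrier of height $V_0$ and width $w$, for energy $E < V_0$, is $T = \left[1 + V_0^2 \sinh^2(\kappa w) / (4E(V_0 - E))\right]^{-1}$ with $\kappa = \sqrt{2m(V_0 - E)}/\hbar$. The sketch plugs in the values used by the script below ($m = 1.9$, $V_0 = 1.5$, width $3L$, central energy $0.2 V_0$). The animated packet contains a spread of energies and bounces off the box walls, so treat this only as an order-of-magnitude guide:

```python
import numpy as np

def transmission(E, V0, width, m=1.0, hbar=1.0):
    """Plane-wave transmission probability through a rectangular barrier, 0 < E < V0."""
    kappa = np.sqrt(2 * m * (V0 - E)) / hbar
    return 1.0 / (1.0 + V0 ** 2 * np.sinh(kappa * width) ** 2 / (4 * E * (V0 - E)))

# Parameters mirroring the script below
hbar, m, V0 = 1.0, 1.9, 1.5
L = hbar / np.sqrt(2 * m * V0)
T = transmission(E=0.2 * V0, V0=V0, width=3 * L, m=m, hbar=hbar)
print(T)  # roughly 0.012
```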

The animation of the result is below (for a brief introduction to the animation capabilities of python, see this post). The top panel shows the position-space wave function, while the bottom panel shows the momentum-space wave function.

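Notice that the height of the potential barrier is far larger than the energy of the particle: the particle's central kinetic energy is only $0.2 V_0$, and the dashed line in the bottom panel marks $k = \sqrt{2mV_0}/\hbar$, the wave number a classical particle would need in order to clear the barrier, which lies well outside the bulk of the momentum distribution. Still, due to quantum effects, a small part of the wave function is able to tunnel through the barrier and reach the other side.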

The python code used to generate this animation is included below. It's pretty long, but I've tried to comment extensively to make the algorithm more clear. If you're so inclined, you might try running the example and adjusting the potential or the input wave function to see the effect on the dynamics of the quantum system.

schrodinger.py
"""
General Numerical Solver for the 1D Time-Dependent Schrodinger's equation.

author: Jake Vanderplas
email: vanderplas@astro.washington.edu
website: http://jakevdp.github.com
license: BSD
Please feel free to use and modify this, but keep the above information. Thanks!
"""

import numpy as np
from matplotlib import pyplot as pl
from matplotlib import animation
from scipy.fftpack import fft,ifft


class Schrodinger(object):
    """
    Class which implements a numerical solution of the time-dependent
    Schrodinger equation for an arbitrary potential
    """
    def __init__(self, x, psi_x0, V_x,
                 k0 = None, hbar=1, m=1, t0=0.0):
        """
        Parameters
        ----------
        x : array_like, float
            length-N array of evenly spaced spatial coordinates
        psi_x0 : array_like, complex
            length-N array of the initial wave function at time t0
        V_x : array_like, float
             length-N array giving the potential at each x
        k0 : float
            the minimum value of k.  Note that, because of the workings of the
            fast fourier transform, the momentum wave-number will be defined
            in the range
              k0 <= k < k0 + 2*pi / dx
            where dx = x[1]-x[0].  If you expect nonzero momentum outside this
            range, you must modify the inputs accordingly.  If not specified,
            k0 will be calculated such that the range is [-pi/dx, pi/dx)
        hbar : float
            value of the reduced Planck constant (default = 1)
        m : float
            particle mass (default = 1)
        t0 : float
            initial time (default = 0)
        """
        # Validation of array inputs
        self.x, psi_x0, self.V_x = map(np.asarray, (x, psi_x0, V_x))
        N = self.x.size
        assert self.x.shape == (N,)
        assert psi_x0.shape == (N,)
        assert self.V_x.shape == (N,)

        # Set internal parameters
        self.hbar = hbar
        self.m = m
        self.t = t0
        self.dt_ = None
        self.N = len(x)
        self.dx = self.x[1] - self.x[0]
        self.dk = 2 * np.pi / (self.N * self.dx)

        # set momentum scale
        if k0 is None:
            self.k0 = -0.5 * self.N * self.dk
        else:
            self.k0 = k0
        self.k = self.k0 + self.dk * np.arange(self.N)

        self.psi_x = psi_x0
        self.compute_k_from_x()

        # variables which hold the evolution factors for each time step
        self.x_evolve_half = None
        self.x_evolve = None
        self.k_evolve = None

        # attributes used for dynamic plotting
        self.psi_x_line = None
        self.psi_k_line = None
        self.V_x_line = None

    def _set_psi_x(self, psi_x):
        self.psi_mod_x = (psi_x * np.exp(-1j * self.k[0] * self.x)
                          * self.dx / np.sqrt(2 * np.pi))

    def _get_psi_x(self):
        return (self.psi_mod_x * np.exp(1j * self.k[0] * self.x)
                * np.sqrt(2 * np.pi) / self.dx)

    def _set_psi_k(self, psi_k):
        self.psi_mod_k = psi_k * np.exp(1j * self.x[0]
                                        * self.dk * np.arange(self.N))

    def _get_psi_k(self):
        return self.psi_mod_k * np.exp(-1j * self.x[0] * 
                                        self.dk * np.arange(self.N))
    
    def _get_dt(self):
        return self.dt_

    def _set_dt(self, dt):
        if dt != self.dt_:
            self.dt_ = dt
            self.x_evolve_half = np.exp(-0.5 * 1j * self.V_x
                                         / self.hbar * dt )
            self.x_evolve = self.x_evolve_half * self.x_evolve_half
            self.k_evolve = np.exp(-0.5 * 1j * self.hbar /
                                    self.m * (self.k * self.k) * dt)
    
    psi_x = property(_get_psi_x, _set_psi_x)
    psi_k = property(_get_psi_k, _set_psi_k)
    dt = property(_get_dt, _set_dt)

    def compute_k_from_x(self):
        self.psi_mod_k = fft(self.psi_mod_x)

    def compute_x_from_k(self):
        self.psi_mod_x = ifft(self.psi_mod_k)

    def time_step(self, dt, Nsteps = 1):
        """
        Perform a series of time-steps via the time-dependent
        Schrodinger Equation.

        Parameters
        ----------
        dt : float
            the small time interval over which to integrate
        Nsteps : int, optional
            the number of intervals to compute.  The total change
            in time at the end of this method will be dt * Nsteps.
            default is Nsteps = 1
        """
        self.dt = dt

        if Nsteps > 0:
            self.psi_mod_x *= self.x_evolve_half

        for i in range(Nsteps - 1):
            self.compute_k_from_x()
            self.psi_mod_k *= self.k_evolve
            self.compute_x_from_k()
            self.psi_mod_x *= self.x_evolve

        self.compute_k_from_x()
        self.psi_mod_k *= self.k_evolve

        self.compute_x_from_k()
        self.psi_mod_x *= self.x_evolve_half

        self.compute_k_from_x()

        self.t += dt * Nsteps


######################################################################
# Helper functions for gaussian wave-packets

def gauss_x(x, a, x0, k0):
    """
    a gaussian wave packet of width a, centered at x0, with momentum k0
    """ 
    return ((a * np.sqrt(np.pi)) ** (-0.5)
            * np.exp(-0.5 * ((x - x0) * 1. / a) ** 2 + 1j * x * k0))

def gauss_k(k,a,x0,k0):
    """
    analytical fourier transform of gauss_x(x), above
    """
    return ((a / np.sqrt(np.pi))**0.5
            * np.exp(-0.5 * (a * (k - k0)) ** 2 - 1j * (k - k0) * x0))


######################################################################
# Utility functions for running the animation

def theta(x):
    """
    theta function :
      returns 0 if x<=0, and 1 if x>0
    """
    x = np.asarray(x)
    y = np.zeros(x.shape)
    y[x > 0] = 1.0
    return y

def square_barrier(x, width, height):
    return height * (theta(x) - theta(x - width))

######################################################################
# Create the animation

# specify time steps and duration
dt = 0.01
N_steps = 50
t_max = 120
frames = int(t_max / float(N_steps * dt))

# specify constants
hbar = 1.0   # planck's constant
m = 1.9      # particle mass

# specify range in x coordinate
N = 2 ** 11
dx = 0.1
x = dx * (np.arange(N) - 0.5 * N)

# specify potential
V0 = 1.5
L = hbar / np.sqrt(2 * m * V0)
a = 3 * L
x0 = -60 * L
V_x = square_barrier(x, a, V0)
V_x[x < -98] = 1E6
V_x[x > 98] = 1E6

# specify initial momentum and quantities derived from it
p0 = np.sqrt(2 * m * 0.2 * V0)
dp2 = p0 * p0 * 1./80
d = hbar / np.sqrt(2 * dp2)

k0 = p0 / hbar
v0 = p0 / m
psi_x0 = gauss_x(x, d, x0, k0)

# define the Schrodinger object which performs the calculations
S = Schrodinger(x=x,
                psi_x0=psi_x0,
                V_x=V_x,
                hbar=hbar,
                m=m,
                k0=-28)

######################################################################
# Set up plot
fig = pl.figure()

# plotting limits
xlim = (-100, 100)
klim = (-5, 5)

# top axes show the x-space data
ymin = 0
ymax = V0
ax1 = fig.add_subplot(211, xlim=xlim,
                      ylim=(ymin - 0.2 * (ymax - ymin),
                            ymax + 0.2 * (ymax - ymin)))
psi_x_line, = ax1.plot([], [], c='r', label=r'$|\psi(x)|$')
V_x_line, = ax1.plot([], [], c='k', label=r'$V(x)$')
center_line = ax1.axvline(0, c='k', ls=':',
                          label = r"$x_0 + v_0t$")

title = ax1.set_title("")
ax1.legend(prop=dict(size=12))
ax1.set_xlabel('$x$')
ax1.set_ylabel(r'$|\psi(x)|$')

# bottom axes show the k-space data
ymin = abs(S.psi_k).min()
ymax = abs(S.psi_k).max()
ax2 = fig.add_subplot(212, xlim=klim,
                      ylim=(ymin - 0.2 * (ymax - ymin),
                            ymax + 0.2 * (ymax - ymin)))
psi_k_line, = ax2.plot([], [], c='r', label=r'$|\psi(k)|$')

p0_line1 = ax2.axvline(-p0 / hbar, c='k', ls=':', label=r'$\pm p_0$')
p0_line2 = ax2.axvline(p0 / hbar, c='k', ls=':')
mV_line = ax2.axvline(np.sqrt(2 * m * V0) / hbar, c='k', ls='--',
                      label=r'$\sqrt{2mV_0}$')
ax2.legend(prop=dict(size=12))
ax2.set_xlabel('$k$')
ax2.set_ylabel(r'$|\psi(k)|$')

V_x_line.set_data(S.x, S.V_x)

######################################################################
# Animate plot
def init():
    psi_x_line.set_data([], [])
    V_x_line.set_data([], [])
    center_line.set_data([], [])

    psi_k_line.set_data([], [])
    title.set_text("")
    return (psi_x_line, V_x_line, center_line, psi_k_line, title)

def animate(i):
    S.time_step(dt, N_steps)
    psi_x_line.set_data(S.x, 4 * abs(S.psi_x))
    V_x_line.set_data(S.x, S.V_x)
    center_line.set_data(2 * [x0 + S.t * p0 / m], [0, 1])

    psi_k_line.set_data(S.k, abs(S.psi_k))
    title.set_text("t = %.2f" % S.t)
    return (psi_x_line, V_x_line, center_line, psi_k_line, title)

# call the animator.  blit=True means only re-draw the parts that have changed.
anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=frames, interval=30, blit=True)


# uncomment the following line to save the video in mp4 format.  This
# requires either mencoder or ffmpeg to be installed on your system

#anim.save('schrodinger_barrier.mp4', fps=15, extra_args=['-vcodec', 'libx264'])

pl.show()

Edit: I changed the legend font properties to play nicely with older matplotlib versions. Thanks to Yann for pointing it out.
