Accelerating Python code with Numba and LLVM
Graham Markall
Compiler Engineer, Embecosm
Twitter: @gmarkall
My background:
Now: Compiler Engineer at Embecosm - GNU Toolchains
Background in Python libraries for HPC (PyOP2, Firedrake)
This talk is an overview of:
A tool that makes Python code go faster by specialising and compiling it.
Random selection of users:
# Mandelbrot function in Python
def mandel(x, y, max_iters):
    c = complex(x, y)
    z = 0j
    for i in range(max_iters):
        z = z*z + c
        if z.real * z.real + z.imag * z.imag >= 4:
            return 255 * i // max_iters
    return 255
# Mandelbrot function in Python using Numba
from numba import jit

@jit
def mandel(x, y, max_iters):
    c = complex(x, y)
    z = 0j
    for i in range(max_iters):
        z = z*z + c
        if z.real * z.real + z.imag * z.imag >= 4:
            return 255 * i // max_iters
    return 255
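A sketch of how such a function might be driven to fill an image (the grid bounds, image size and the create_fractal helper are illustrative, not from the original deck):

import numpy as np

@jit
def create_fractal(xmin, xmax, ymin, ymax, image, iters):
    # Map each pixel to a point in the complex plane and colour it
    height, width = image.shape
    for i in range(width):
        x = xmin + i * (xmax - xmin) / width
        for j in range(height):
            y = ymin + j * (ymax - ymin) / height
            image[j, i] = mandel(x, y, iters)
    return image

image = np.zeros((1024, 1536), dtype=np.uint8)
create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20)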
Implementation | Speedup |
---|---|
CPython | 1x |
NumPy array-wide operations | 13x |
Numba (CPU) | 120x |
Numba (NVIDIA Tesla K20c) | 2100x |
Times in msec:
Example | CPython | Numba | Speedup |
---|---|---|---|
Black-Scholes | 969 | 433 | 2.2x |
Check Neighbours | 550 | 28 | 19.9x |
IS Distance | 372 | 70 | 5.4x |
Pairwise | 62 | 12 | 5.1x |
Calling a @jit function: Numba checks whether a compiled specialisation already exists for the argument types:
- Yes: retrieve the compiled code from the cache
- No: compile a new specialisation
@jit
def add(a, b):
    return a + b

def add_python(a, b):
    return a + b
>>> %timeit add(1, 2)
10000000 loops, best of 3: 163 ns per loop
>>> %timeit add_python(1, 2)
10000000 loops, best of 3: 85.3 ns per loop
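For a function this trivial, the dispatch overhead outweighs the compiled code, which is why the jitted add is slower than the pure-Python one above. The compiled specialisations can be inspected through the dispatcher object, for example:

add(1, 2)        # compiles an (int64, int64) specialisation
add(1.0, 2.5)    # compiles a (float64, float64) specialisation

add.signatures       # list of argument-type tuples compiled so far
add.inspect_types()  # print Numba's inferred types for each specialisation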
float32 + float32 -> float32
int32 + float32 -> float64
def f(a, b):    # a := float32, b := float32
    c = a + b   # c := float32
    return c    # return := float32
Example typing 1:

def select(a, b, c):  # a := float32, b := float32, c := bool
    if c:
        ret = a       # ret := float32
    else:
        ret = b       # ret := float32
    return ret        # return := {float32, float32}
                      #        => float32
Example typing 2:

def select(a, b, c):  # a := tuple(int32, int32), b := float32,
                      # c := bool
    if c:
        ret = a       # ret := tuple(int32, int32)
    else:
        ret = b       # ret := float32
    return ret        # return := {tuple(int32, int32), float32}
                      #        => no common type: unification fails
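A small sketch of what that unification failure looks like in practice (using @jit(nopython=True); the exact error wording varies between Numba versions):

from numba import jit

@jit(nopython=True)
def select(a, b, c):
    if c:
        ret = a
    else:
        ret = b
    return ret

select(1.0, 2.0, True)     # fine: both branches type as float64
select((1, 2), 2.0, True)  # TypingError: cannot unify the tuple and float types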
llvmlite user community (examples):
- M-Labs Artiq - control system for quantum information experiments
- PPC - Python Pascal Compiler
- Various university compilers courses
- Numba!
Lightweight interface to LLVM through an IR parser
IR builder reimplemented in pure Python
LLVM versions 3.5 - 3.8 supported
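A minimal sketch of the pure-Python IR builder, following the example in the llvmlite documentation (building a function that adds two doubles):

from llvmlite import ir

# Build a module containing: double fadd(double a, double b) { return a + b; }
module = ir.Module(name="example")
fnty = ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType()))
func = ir.Function(module, fnty, name="fadd")

block = func.append_basic_block(name="entry")
builder = ir.IRBuilder(block)
a, b = func.args
builder.ret(builder.fadd(a, b, name="res"))

print(module)  # textual LLVM IR, ready to hand to the LLVM binding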
Python features supported inside functions decorated with @jit
(note: classes themselves cannot be decorated with @jit):
Types:
- int, bool, float, complex
- tuple, list, None
- bytes, bytearray, memoryview (and other buffer-like objects)
Built-in functions:
- abs, enumerate, len, min, max, print, range, round, zip
Standard library:
- cmath, math, random, ctypes...
Third-party:
- cffi, numpy
Comprehensive list: http://numba.pydata.org/numba-doc/0.21.0/reference/pysupported.html
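For example, a nopython-mode function using only constructs from the lists above (the function itself is an illustrative sketch, not from the deck):

import math
from numba import jit

@jit(nopython=True)
def clipped_rms(values, lo, hi):
    # range, len, min, max and the math module are all usable in nopython mode
    total = 0.0
    for i in range(len(values)):
        v = min(max(values[i], lo), hi)
        total += v * v
    return math.sqrt(total / len(values))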
All kinds of arrays: scalar and structured type
- except when containing Python objects
Allocation, iterating, indexing, slicing
Reductions: argmax(), max(), prod() etc.
Scalar types and values (including datetime64 and timedelta64)
Array expressions, but no broadcasting
See reference manual: http://numba.pydata.org/numba-doc/0.21.0/reference/numpysupported.html
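A sketch combining several of these NumPy features (allocation, slicing, an array expression and a reduction) in nopython mode:

import numpy as np
from numba import jit

@jit(nopython=True)
def row_norms(a):
    out = np.empty(a.shape[0])               # allocation
    for i in range(a.shape[0]):
        row = a[i, :]                        # slicing
        out[i] = np.sqrt((row * row).sum())  # array expression + reduction
    return out

row_norms(np.random.rand(4, 3))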
from numba import vectorize
import numpy as np

@vectorize
def rel_diff(x, y):
    return 2 * (x - y) / (x + y)

Call:

a = np.arange(1000, dtype=np.float32)
b = a * 2 + 1
rel_diff(a, b)
from numba import guvectorize, int64

@guvectorize([(int64[:], int64[:], int64[:])], '(n),()->(n)')
def g(x, y, res):
    for i in range(x.shape[0]):
        res[i] = x[i] + y[0]
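Calling it (a sketch; the result array is allocated by the ufunc machinery):

import numpy as np

a = np.arange(5, dtype=np.int64)
g(a, 10)                  # -> array([10, 11, 12, 13, 14])
g(a, np.array([10, 20]))  # broadcasts to a (2, 5) result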
In the layout signature '(n),()->(n)', the symbols before the '->' are the inputs (not allocated); those after it are the outputs (allocated for you).

Matrix-vector products:
@guvectorize([(float64[:, :], float64[:], float64[:])],
             '(m,n),(n)->(m)')
def batch_matmul(M, v, y):
    pass  # ...

Fixed outputs (e.g. max and min):

@guvectorize([(float64[:], float64[:], float64[:])],
             '(n)->(),()')
def max_min(arr, largest, smallest):
    pass  # ...
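One way the elided bodies might be filled in (a sketch; scalar outputs declared with '()' dims are written through index 0):

from numba import guvectorize, float64

@guvectorize([(float64[:, :], float64[:], float64[:])],
             '(m,n),(n)->(m)')
def batch_matmul(M, v, y):
    for i in range(M.shape[0]):
        acc = 0.0
        for j in range(M.shape[1]):
            acc += M[i, j] * v[j]
        y[i] = acc

@guvectorize([(float64[:], float64[:], float64[:])],
             '(n)->(),()')
def max_min(arr, largest, smallest):
    largest[0] = arr[0]
    smallest[0] = arr[0]
    for i in range(1, arr.shape[0]):
        largest[0] = max(largest[0], arr[i])
        smallest[0] = min(smallest[0], arr[i])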
Loop lifting:

import numpy as np
from numba import jit

@jit
def sum_strings(arr):
    intarr = np.empty(len(arr), dtype=np.int32)
    for i in range(len(arr)):
        intarr[i] = int(arr[i])

    sum = 0
    # Lifted loop: compiled in nopython mode even though the
    # string handling above forces object mode
    for i in range(len(intarr)):
        sum += intarr[i]

    return sum
Start off with just jitting it and see if it runs
Use numba --annotate-html to see what Numba sees
Start adding nopython=True to your innermost functions
Try to fix each function and then move on
- Need to make sure all inputs and outputs are Numba-compatible types
- No lists, dicts, etc.
Don't forget to assess performance at each stage
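A sketch of that workflow: force nopython mode on the innermost numerical kernel first, so unsupported constructs become compile-time errors rather than silent object-mode fallbacks (the function names here are illustrative):

from numba import jit

@jit(nopython=True)
def kernel(a):
    total = 0.0
    for i in range(a.shape[0]):
        total += a[i] * a[i]
    return total

# The outer driver can stay as plain Python (or plain @jit) while tuning
def driver(arrays):
    return [kernel(a) for a in arrays]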
from numba import jit, float64

@jit(float64(float64, float64))
def add(a, b):
    return a + b
Giving an explicit signature such as float64(float64, float64) is probably unnecessary!

Two separate loops over the same range:

for i in range(len(X)):
    Y[i] = sin(X[i])
for i in range(len(Y)):
    Z[i] = Y[i] * Y[i]

Fused into a single loop:

for i in range(len(X)):
    Y[i] = sin(X[i])
    Z[i] = Y[i] * Y[i]

With the temporary array Y replaced by a scalar:

for i in range(len(X)):
    Y = sin(X[i])
    Z[i] = Y * Y
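As a complete jitted function, the final version might look like this (a sketch; math.sin stands in for the bare sin above):

import math
from numba import jit

@jit(nopython=True)
def sin_squared(X, Z):
    for i in range(len(X)):
        y = math.sin(X[i])
        Z[i] = y * y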
Set the environment variable NUMBA_DISABLE_JIT=1 to disable compilation (useful for debugging).

Releasing the GIL:

@numba.jit(nogil=True)
def my_function(x, y, z):
    ...

See examples/nogil.py in the Numba distribution.

Serial and parallel ufuncs with @vectorize:
from numba import vectorize, float64
import numpy as np

@vectorize([float64(float64, float64)])
def rel_diff_serial(x, y):
    return 2 * (x - y) / (x + y)

@vectorize([float64(float64, float64)], target='parallel')
def rel_diff_parallel(x, y):
    return 2 * (x - y) / (x + y)

For 10^8 elements, on my laptop (i7-2620M, 2 cores + HT):

%timeit rel_diff_serial(x, y)
# 1 loop, best of 3: 556 ms per loop

%timeit rel_diff_parallel(x, y)
# 1 loop, best of 3: 272 ms per loop
Pass target='parallel' or target='cuda' to the @vectorize decorator.

Note that a bare @vectorize(target='parallel') does not work: these targets need an explicit signature list, i.e. @vectorize([args], target='parallel').
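For instance, a CUDA version of the ufunc above might be declared like this (a sketch; it needs a CUDA-capable GPU and the CUDA toolkit installed):

from numba import vectorize, float64

@vectorize([float64(float64, float64)], target='cuda')
def rel_diff_cuda(x, y):
    return 2 * (x - y) / (x + y)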
Dispatch based on argument type - the 1-norm for a scalar, a vector and a matrix:

import math
import numpy as np

def scalar_1norm(x):
    '''Absolute value of x'''
    return math.fabs(x)

def vector_1norm(x):
    '''Sum of absolute values of x'''
    return np.sum(np.abs(x))

def matrix_1norm(x):
    '''Max sum of absolute values of columns of x'''
    colsums = np.zeros(x.shape[1])
    for i in range(len(colsums)):
        colsums[i] = np.sum(np.abs(x[:, i]))
    return np.max(colsums)
JITting them into a single function using @generated_jit:

from numba import generated_jit, types

def bad_1norm(x):
    raise TypeError("Unsupported type for 1-norm")

@generated_jit(nopython=True)
def l1_norm(x):
    if isinstance(x, types.Number):
        return scalar_1norm
    elif isinstance(x, types.Array) and x.ndim == 1:
        return vector_1norm
    elif isinstance(x, types.Array) and x.ndim == 2:
        return matrix_1norm
    else:
        return bad_1norm
Calling the generated function (M and N are example sizes):

# Calling
x0 = np.random.rand()
x1 = np.random.rand(M)
x2 = np.random.rand(M * N).reshape(M, N)

l1_norm(x0)
l1_norm(x1)
l1_norm(x2)

# Raises TypeError("Unsupported type for 1-norm")
l1_norm(np.zeros((10, 10, 10)))
Tips for writing @generated_jit functions:

- Look in numba.types to see the available types and their attributes
- Useful classes include Array, Number, Integer, Float and List
- Useful attributes include ndim, the array dtype, and the tuple dtype or types
- Buffer is the base for a lot of things, including Array

Make sure every branch returns a function - if the generated function returns None you get:

File "/home/pydata/anaconda3/envs/pydata/lib/python3.5/inspect.py", line 2156,
  in _signature_from_callable
    raise TypeError('{!r} is not a callable object'.format(obj))
TypeError: None is not a callable object
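A quick interpreter sketch of that introspection (the exact reprs vary by Numba version):

from numba import types

isinstance(types.float64, types.Number)     # True
arr_t = types.Array(types.float64, 2, 'C')  # a C-contiguous 2-d float64 array type
arr_t.ndim                                  # 2
arr_t.dtype                                 # float64
isinstance(arr_t, types.Buffer)             # True - Array derives from Buffer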
Original AoS layout using a structured dtype:

import numpy as np
from numba import jit

N = 1000  # example size

dtype = [
    ('x', np.float64),
    ('y', np.float64),
    ('z', np.float64),
    ('w', np.int32)
]

aos = np.zeros(N, dtype)

@jit(nopython=True)
def set_x_aos(v):
    for i in range(len(v)):
        v[i]['x'] = i

set_x_aos(aos)
The SoA layout as a jitted class:

from numba import jitclass, int32, float64

vector_spec = [
    ('N', int32),
    ('x', float64[:]),
    ('y', float64[:]),
    ('z', float64[:]),
    ('w', int32[:])
]

@jitclass(vector_spec)
class VectorSoA(object):
    def __init__(self, N):
        self.N = N
        self.x = np.zeros(N, dtype=np.float64)
        self.y = np.zeros(N, dtype=np.float64)
        self.z = np.zeros(N, dtype=np.float64)
        self.w = np.zeros(N, dtype=np.int32)

soa = VectorSoA(N)
# Example iterating over x with the AoS layout:
@jit(nopython=True)
def set_x_aos(v):
    for i in range(len(v)):
        v[i]['x'] = i

# Example iterating over x with the SoA layout:
@jit(nopython=True)
def set_x_soa(v):
    for i in range(v.N):
        v.x[i] = i
@jitclass gotchas:

Calling a @jit function from a method, and names beginning with _ or __, are not supported yet - see PR #1851.

Declaring a spec member as np.int32 but assigning np.float64 to it gives:

numba.errors.LoweringError: Failed at nopython
(nopython mode backend)
Internal error:
TypeError: Can only insert i32* at [4] in
{i8*, i8*, i64, i64, i32*, [1 x i64], [1 x i64]}:
got float*
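A hypothetical minimal reproduction of the second gotcha (np.zeros defaults to float64, clashing with the int32[:] spec; the exact error text varies by Numba version):

import numpy as np
from numba import jitclass, int32

@jitclass([('w', int32[:])])
class Bad(object):
    def __init__(self, N):
        self.w = np.zeros(N)  # float64 array assigned to an int32[:] member

Bad(10)  # raises an error like the LoweringError shown above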
Wrapping an external library (Intel VML) with cffi:

Two modes:

Note: this is an example of a general procedure for wrapping a library and using it with Numba. The demo won't run without the VML development files.

Accelerate from Continuum provides the VML functions as ufuncs.

Use register_module and register_type to tell Numba how to map the wrapped module and its types.

ffi.from_buffer does not type check.
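A sketch of that general procedure using cffi in ABI mode (the library name and the vdSin signature are assumptions; as noted above, it won't run without VML installed):

import cffi
import numpy as np
from numba import jit

ffi = cffi.FFI()
ffi.cdef("void vdSin(int n, const double a[], double r[]);")
vml = ffi.dlopen("libmkl_rt.so")  # assumed library name
vdSin = vml.vdSin                 # bind the cffi function to a global name

@jit(nopython=True)
def vml_sin(x, out):
    # ffi.from_buffer passes a raw pointer to the array data - no type checking!
    vdSin(len(x), ffi.from_buffer(x), ffi.from_buffer(out))

x = np.linspace(0, 1, 10)
out = np.empty_like(x)
vml_sin(x, out)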
Example (the Interval class):