Ho scritto una funzione Python che calcola interazioni elettromagnetiche a coppie tra un numero maggiore (N ~ 10^3) di particelle e memorizza i risultati in un complesso NxN128 ndarray. Funziona, ma è la parte più lenta di un programma più grande, impiegando circa 40 secondi quando N = 900 [corretto]. Il codice originale è simile al seguente:L'accelerazione di Cython non è grande come previsto
import numpy as np
def interaction(s,alpha,kprop): # s is an Nx3 real array
# alpha is complex
# kprop is float
ndipoles = s.shape[0]
Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=np.complex128)
I = np.array([[1,0,0],[0,1,0],[0,0,1]])
im = complex(0,1)
k2 = kprop*kprop
for i in range(ndipoles):
xi = s[i,:]
for j in range(ndipoles):
if i != j:
xj = s[j,:]
dx = xi-xj
R = np.sqrt(dx.dot(dx))
n = dx/R
kR = kprop*R
kR2 = kR*kR
A = ((1./kR2) - im/kR)
nxn = np.outer(n, n)
nxn = (3*A-1)*nxn + (1-A)*I
nxn *= -alpha*(k2*np.exp(im*kR))/R
else:
nxn = I
Amat[i,:,j,:] = nxn
return(Amat.reshape((3*ndipoles,3*ndipoles)))
non avevo mai usato in precedenza Cython, ma che sembrava un buon punto di partenza nel mio sforzo di velocizzare le cose, così ho praticamente alla cieca adattato le tecniche che ho trovato in online esercitazioni. Ho avuto un po 'di accelerazione (30 secondi contro 40 secondi), ma non così drammatico come mi aspettavo, quindi mi chiedo se sto facendo qualcosa di sbagliato o mi manca un passaggio critico. Quanto segue è il mio miglior tentativo di cythonizing la routine di cui sopra:
import numpy as np
cimport numpy as np
DTYPE = np.complex128
ctypedef np.complex128_t DTYPE_t
def interaction(np.ndarray s, DTYPE_t alpha, float kprop):
cdef float k2 = kprop*kprop
cdef int i,j
cdef np.ndarray xi, xj, dx, n, nxn
cdef float R, kR, kR2
cdef DTYPE_t A
cdef int ndipoles = s.shape[0]
cdef np.ndarray Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=DTYPE)
cdef np.ndarray I = np.array([[1,0,0],[0,1,0],[0,0,1]])
cdef DTYPE_t im = complex(0,1)
for i in range(ndipoles):
xi = s[i,:]
for j in range(ndipoles):
if i != j:
xj = s[j,:]
dx = xi-xj
R = np.sqrt(dx.dot(dx))
n = dx/R
kR = kprop*R
kR2 = kR*kR
A = ((1./kR2) - im/kR)
nxn = np.outer(n, n)
nxn = (3*A-1)*nxn + (1-A)*I
nxn *= -alpha*(k2*np.exp(im*kR))/R
else:
nxn = I
Amat[i,:,j,:] = nxn
return(Amat.reshape((3*ndipoles,3*ndipoles)))
Numpy è una libreria di C. E usa BLAS per fare l'algebra, quindi è piuttosto veloce. Non capisco davvero come funziona internals cython, ma essendo già numpy codice C, il guadagno di velocità è in qualsiasi cosa "non numpy". –
Supponevo che un numero sufficiente di operazioni line-by-line all'interno del ciclo annidato richiedesse l'invocazione diretta dell'interprete Python e che tali linee fossero quindi probabilmente il costo dominante relativo a Numpy - ma forse no? –
Puoi provare a digitare i tuoi array numpy, in modo che il compilatore conosca i tipi all'interno degli array. Non sono sicuro di quanto grande sarà la differenza, però. Potresti voler eseguire un profiler sul codice Python per vedere dove stai effettivamente perdendo la velocità. Se la maggior parte del tempo viene impiegata in routine di numpy, non si otterrà molto usando cython. – cel