2013-07-22 29 views
5

Sono relativamente nuovo a Python e ho un ciclo annidato per. Poiché i cicli for richiedono un po 'di tempo per essere eseguiti, sto cercando di capire come convalidare questo codice in modo che possa funzionare più velocemente.Vettorizzazione per loop NumPy

In questo caso, coord è un array tridimensionale in cui coord [x, 0, 0] e coord [x, 0, 1] sono numeri interi e coord [x, 0, 2] è 0 o 1. H è una matrice sparsa SciPy e x_dist, y_dist, z_dist e a sono tutti float.

# x_dist, y_dist, and z_dist are floats 
# coord is a num x 1 x 3 numpy array where num can go into the hundreds of thousands 
num = coord.shape[0]  
H = sparse.lil_matrix((num, num)) 
for i in xrange(num): 
    for j in xrange(num): 
     if (np.absolute(coord[i, 0, 0] - coord[j, 0, 0]) <= 2 and 
       (np.absolute(coord[i, 0, 1] - coord[j, 0, 1]) <= 1)): 

      x = ((coord[i, 0, 0] * x_dist + coord[i, 0, 2] * z_dist) - 
       (coord[j, 0, 0] * x_dist + coord[j, 0, 2] * z_dist)) 

      y = (coord[i, 0, 1] * y_dist) - (coord[j, 0, 1] * y_dist) 

      if a - 0.5 <= np.sqrt(x ** 2 + y ** 2) <= a + 0.5: 
       H[i, j] = -2.7 

Ho anche letto che la trasmissione con NumPy, mentre molto più veloce, può portare a grandi quantità di utilizzo della memoria da array temporaneo. Sarebbe meglio andare sulla via della vettorizzazione o provare ad usare qualcosa come Cython?

risposta

2

La natura del calcolo rende difficile la vettorizzazione con metodi numpy con cui ho familiarità. Penso che la soluzione migliore in termini di velocità e utilizzo della memoria sia cython. Tuttavia, è possibile ottenere un po 'di velocità usando numba. Ecco un esempio (nota che normalmente si utilizza autojit come decoratore):

import numpy as np 
from scipy import sparse 
import time 
from numba.decorators import autojit 
x_dist=.5 
y_dist = .5 
z_dist = .4 
a = .6 
coord = np.random.normal(size=(1000,1000,1000)) 

def run(coord, x_dist,y_dist, z_dist, a): 
    num = coord.shape[0]  
    H = sparse.lil_matrix((num, num)) 
    for i in xrange(num): 
     for j in xrange(num): 
      if (np.absolute(coord[i, 0, 0] - coord[j, 0, 0]) <= 2 and 
        (np.absolute(coord[i, 0, 1] - coord[j, 0, 1]) <= 1)): 

       x = ((coord[i, 0, 0] * x_dist + coord[i, 0, 2] * z_dist) - 
        (coord[j, 0, 0] * x_dist + coord[j, 0, 2] * z_dist)) 

       y = (coord[i, 0, 1] * y_dist) - (coord[j, 0, 1] * y_dist) 

       if a - 0.5 <= np.sqrt(x ** 2 + y ** 2) <= a + 0.5: 
        H[i, j] = -2.7 
    return H 

runaj = autojit(run) 

t0 = time.time() 
run(coord,x_dist,y_dist, z_dist, a) 
t1 = time.time() 
print 'First Original Runtime:', t1 - t0 

t0 = time.time() 
run(coord,x_dist,y_dist, z_dist, a) 
t1 = time.time() 
print 'Second Original Runtime:', t1 - t0 

t0 = time.time() 
run(coord,x_dist,y_dist, z_dist, a) 
t1 = time.time() 
print 'Third Original Runtime:', t1 - t0 

t0 = time.time() 
runaj(coord,x_dist,y_dist, z_dist, a) 
t1 = time.time() 
print 'First Numba Runtime:', t1 - t0 

t0 = time.time() 
runaj(coord,x_dist,y_dist, z_dist, a) 
t1 = time.time() 
print 'Second Numba Runtime:', t1 - t0 

t0 = time.time() 
runaj(coord,x_dist,y_dist, z_dist, a) 
t1 = time.time() 
print 'Third Numba Runtime:', t1 - t0 

ottengo questo output:

First Original Runtime: 21.3574919701 
Second Original Runtime: 15.7615520954 
Third Original Runtime: 15.3634860516 
First Numba Runtime: 9.87108802795 
Second Numba Runtime: 9.32944011688 
Third Numba Runtime: 9.32300305367 
+0

Grazie per il suggerimento! Tuttavia, quando provo a mettere questo in uno script (usando il decoratore @autojit) e il tempo con IPython (% timeit% eseguire Test.py), ottengo risultati che sono costantemente più lenti del normale Python. Hai idea del perché questo sta accadendo? – sonicxml

+0

@sonicxml Questo è interessante. Stai usando gli stessi dati del mio esempio? Autojit ha bisogno di compilare la tua funzione per ogni nuovo tipo di dati che gli passi, e lo fa in fase di runtime. Pertanto, per piccoli esempi potrebbe essere più lento a causa del tempo di compilazione. Potrebbe essere il problema dell'esempio che stai utilizzando? – jcrudy

+0

Ahh okay. Avevo eseguito un array più piccolo solo per testarlo, ma ora che ho reso l'array più grande numba sta diventando molto più veloce di python. – sonicxml

5

Questo è come mi sarei vettorizzare il codice, qualche discussione sui caveat tardi :

import numpy as np 
import scipy.sparse as sps 

idx = ((np.abs(coord[:, 0, 0] - coord[:, 0, 0, None]) <= 2) & 
     (np.abs(coord[:, 0, 1] - coord[:, 0, 1, None]) <= 1)) 

rows, cols = np.nonzero(idx) 
x = ((coord[rows, 0, 0]-coord[cols, 0, 0]) * x_dist + 
    (coord[rows, 0, 2]-coord[cols, 0, 2]) * z_dist) 
y = (coord[rows, 0, 1]-coord[cols, 0, 1]) * y_dist 
r2 = x*x + y*y 

idx = ((a - 0.5)**2 <= r2) & (r2 <= (a + 0.5)**2) 

rows, cols = rows[idx], cols[idx] 
data = np.repeat(2.7, len(rows)) 

H = sps.coo_matrix((data, (rows, cols)), shape=(num, num)).tolil() 

come annotato, i problemi stanno per venire con il primo idx matrice, in quanto sarà di forma (num, num), quindi WIL Probabilmente farò a pezzi la tua memoria se num è "tra le centinaia di migliaia".

Una possibile soluzione è quella di suddividere il problema in blocchi gestibili. Se si dispone di un array di 100.000 elementi, è possibile dividerlo in 100 blocchi di 1.000 elementi ed eseguire una versione modificata del codice sopra per ognuna delle 10.000 combinazioni di blocchi. Avresti solo bisogno di un array di 1.000.000 elementi idx (che potresti pre-allocare e riutilizzare per prestazioni migliori) e avresti un ciclo di soli 10.000 iterazioni, invece dei 10.000.000.000 della tua attuale implementazione. È una sorta di schema di parallelizzazione di un povero uomo, che puoi effettivamente migliorare facendo in modo che molti di questi blocchi vengano elaborati in parallelo se hai un computer multi-core.

+0

Wow! È molto più veloce del mio codice originale. Riguardo ai pezzi: nel mio codice originale, paragono ogni punto in comune con ogni altro punto in coord. Forse mi manca questo, ma quando rompo il codice in blocchi, come posso confrontare i punti tra i blocchi? – sonicxml

Problemi correlati