2013-07-08 6 views
5

Sto cercando di capire come velocizzare una funzione Python che usa numpy. L'output che ho ricevuto da lineprofiler è sotto, e questo dimostra che la maggior parte del tempo viene speso sulla linea ind_y, ind_x = np.where(seg_image == i).Accelera numpy.where per l'estrazione di segmenti interi?

seg_image è un array intero che è il risultato del segmentare un'immagine, trovando quindi i pixel dove seg_image == i estrae uno specifico oggetto segmentato. Sto eseguendo il looping di molti di questi oggetti (nel codice qui sotto sto solo eseguendo il ciclo 5 per il test, ma eseguirò un ciclo di oltre 20.000) e ci vorrà molto tempo per essere eseguito!

Esiste un modo per velocizzare la chiamata np.where? O, in alternativa, che la penultima riga (che prende anche una buona parte del tempo) può essere accelerata?

La soluzione ideale sarebbe quella di eseguire il codice sull'intera matrice in una volta, piuttosto che in loop, ma non penso che questo sia possibile in quanto vi sono effetti collaterali ad alcune delle funzioni che devo eseguire (per Ad esempio, la dilatazione di un oggetto segmentato può renderlo "collide" con la regione successiva e quindi fornire risultati errati in seguito).

Qualcuno ha qualche idea?

Line #  Hits   Time Per Hit % Time Line Contents 
============================================================== 
    5           def correct_hot(hot_image, seg_image): 
    6   1  239810 239810.0  2.3  new_hot = hot_image.copy() 
    7   1  572966 572966.0  5.5  sign = np.zeros_like(hot_image) + 1 
    8   1  67565 67565.0  0.6  sign[:,:] = 1 
    9   1  1257867 1257867.0  12.1  sign[hot_image > 0] = -1 
    10           
    11   1   150 150.0  0.0  s_elem = np.ones((3, 3)) 
    12           
    13            #for i in xrange(1,seg_image.max()+1): 
    14   6   57  9.5  0.0  for i in range(1,6): 
    15   5  6092775 1218555.0  58.5   ind_y, ind_x = np.where(seg_image == i) 
    16           
    17             # Get the average HOT value of the object (really simple!) 
    18   5   2408 481.6  0.0   obj_avg = hot_image[ind_y, ind_x].mean() 
    19           
    20   5   333  66.6  0.0   miny = np.min(ind_y) 
    21             
    22   5   162  32.4  0.0   minx = np.min(ind_x) 
    23             
    24           
    25   5   369  73.8  0.0   new_ind_x = ind_x - minx + 3 
    26   5   113  22.6  0.0   new_ind_y = ind_y - miny + 3 
    27           
    28   5   211  42.2  0.0   maxy = np.max(new_ind_y) 
    29   5   143  28.6  0.0   maxx = np.max(new_ind_x) 
    30           
    31             # 7 is + 1 to deal with the zero-based indexing, + 2 * 3 to deal with the 3 cell padding above 
    32   5   217  43.4  0.0   obj = np.zeros((maxy+7, maxx+7)) 
    33           
    34   5   158  31.6  0.0   obj[new_ind_y, new_ind_x] = 1 
    35           
    36   5   2482 496.4  0.0   dilated = ndimage.binary_dilation(obj, s_elem) 
    37   5   1370 274.0  0.0   border = mahotas.borders(dilated) 
    38           
    39   5   122  24.4  0.0   border = np.logical_and(border, dilated) 
    40           
    41   5   355  71.0  0.0   border_ind_y, border_ind_x = np.where(border == 1) 
    42   5   136  27.2  0.0   border_ind_y = border_ind_y + miny - 3 
    43   5   123  24.6  0.0   border_ind_x = border_ind_x + minx - 3 
    44           
    45   5   645 129.0  0.0   border_avg = hot_image[border_ind_y, border_ind_x].mean() 
    46           
    47   5  2167729 433545.8  20.8   new_hot[seg_image == i] = (new_hot[ind_y, ind_x] + (sign[ind_y, ind_x] * np.abs(obj_avg - border_avg))) 
    48   5  10179 2035.8  0.1   print obj_avg, border_avg 
    49           
    50   1   4  4.0  0.0  return new_hot 

risposta

4

EDIT ho lasciato la mia risposta originale al fondo per la cronaca, ma ho davvero guardato nel codice in modo più dettagliato durante il pranzo, e penso che l'utilizzo di np.where è un grosso errore:

In [63]: a = np.random.randint(100, size=(1000, 1000)) 

In [64]: %timeit a == 42 
1000 loops, best of 3: 950 us per loop 

In [65]: %timeit np.where(a == 42) 
100 loops, best of 3: 7.55 ms per loop 

Si potrebbe ottenere un array booleano (che è possibile utilizzare per l'indicizzazione) in 1/8 del tempo necessario per ottenere le coordinate effettive dei punti !!!

C'è naturalmente il ritaglio delle caratteristiche che si fanno, ma ndimage ha una funzione find_objects che restituisce fette che racchiudono, e sembra essere molto veloce:

In [66]: %timeit ndimage.find_objects(a) 
100 loops, best of 3: 11.5 ms per loop 

Questo restituisce una lista di tuple di fette allegando tutti gli oggetti, nel 50% di tempo in più necessario per trovare gli indici di un singolo oggetto.

Potrebbe non funzionare fuori dalla scatola come non posso provarlo in questo momento, ma vorrei ristrutturare il codice in qualcosa di simile al seguente:

def correct_hot_bis(hot_image, seg_image): 
    # Need this to not index out of bounds when computing border_avg 
    hot_image_padded = np.pad(hot_image, 3, mode='constant', 
           constant_values=0) 
    new_hot = hot_image.copy() 
    sign = np.ones_like(hot_image, dtype=np.int8) 
    sign[hot_image > 0] = -1 
    s_elem = np.ones((3, 3)) 

    for j, slice_ in enumerate(ndimage.find_objects(seg_image)): 
     hot_image_view = hot_image[slice_] 
     seg_image_view = seg_image[slice_] 
     new_shape = tuple(dim+6 for dim in hot_image_view.shape) 
     new_slice = tuple(slice(dim.start, 
           dim.stop+6, 
           None) for dim in slice_) 
     indices = seg_image_view == j+1 

     obj_avg = hot_image_view[indices].mean() 

     obj = np.zeros(new_shape) 
     obj[3:-3, 3:-3][indices] = True 

     dilated = ndimage.binary_dilation(obj, s_elem) 
     border = mahotas.borders(dilated) 
     border &= dilated 

     border_avg = hot_image_padded[new_slice][border == 1].mean() 

     new_hot[slice_][indices] += (sign[slice_][indices] * 
            np.abs(obj_avg - border_avg)) 

    return new_hot 

si sarebbe ancora bisogno di capire la collisioni, ma si potrebbe ottenere circa un 2x speed-up calcolando tutti gli indici contemporaneamente utilizzando un approccio basato np.unique:

a = np.random.randint(100, size=(1000, 1000)) 

def get_pos(arr): 
    pos = [] 
    for j in xrange(100): 
     pos.append(np.where(arr == j)) 
    return pos 

def get_pos_bis(arr): 
    unq, flat_idx = np.unique(arr, return_inverse=True) 
    pos = np.argsort(flat_idx) 
    counts = np.bincount(flat_idx) 
    cum_counts = np.cumsum(counts) 
    multi_dim_idx = np.unravel_index(pos, arr.shape) 
    return zip(*(np.split(coords, cum_counts) for coords in multi_dim_idx)) 

In [33]: %timeit get_pos(a) 
1 loops, best of 3: 766 ms per loop 

In [34]: %timeit get_pos_bis(a) 
1 loops, best of 3: 388 ms per loop 

Nota che i pixel per ciascun o gli oggetti vengono restituiti in un ordine diverso, quindi non è possibile confrontare semplicemente i rendimenti di entrambe le funzioni per valutare l'uguaglianza. Ma dovrebbero entrambi restituire lo stesso.

+0

Questo è meraviglioso, fantastico e sorprendente - grazie! La prima volta che l'ho eseguito ho scoperto che era in realtà più lento del mio codice originale, ma poi ho modificato parte del codice in modo che facesse tutto il lavoro (dilatazione, bordi, ecc.) In un piccolo array piuttosto che nell'array enorme - modificando come è stato calcolato new_shape. Ora ho avuto un enorme aumento di velocità. Su una delle immagini con cui sto lavorando, la vecchia versione impiegava due ore e mezza, la nuova versione impiegava 11 secondi! – robintw

+0

Oops! Sì, sembra che l'espressione del generatore debba essere 'new_shape = tuple (dim + 6 per dim in hot_image_view.shape)', e non 'new_shape = tuple (dim + 6 per dim in hot_image.shape)'. È quello che hai cambiato? Per favore, sentiti libero di modificare la mia risposta per riflettere il codice di lavoro. – Jaime

2

Una cosa si potrebbe fare per lo stesso un po 'di tempo è quello di salvare il risultato di seg_image == i in modo che non c'è bisogno di calcolare due volte. Lo stai calcolando sulle linee 15 & 47, potresti aggiungere seg_mask = seg_image == i e riutilizzare quel risultato (Potrebbe anche essere utile separare quel pezzo per scopi di profilazione).

Mentre ci sono alcune altre cose minori che è possibile fare per ottenere un po 'di prestazioni, il problema di root è che si sta utilizzando un algoritmo O (M * N) dove M è il numero di segmenti e N è la dimensione della tua immagine. Non è ovvio per me dal tuo codice se esiste un algoritmo più veloce per realizzare la stessa cosa, ma è il primo posto in cui proverei a cercare un aumento della velocità.