2012-05-24 18 views
7

Ho bisogno di implementare una funzione per sommare gli elementi di un array con una lunghezza di sezione variabile. Così,cython numpy accumulate function

a = np.arange(10) 
section_lengths = np.array([3, 2, 4]) 
out = accumulate(a, section_lengths) 
print out 
array([ 3., 7., 35.]) 

ho tentato un'implementazione in cython qui:

https://gist.github.com/2784725

per le prestazioni che sto paragonando al numpy soluzione pura per il caso in cui i section_lengths sono tutti uguali:

LEN = 10000 
b = np.ones(LEN, dtype=np.int) * 2000 
a = np.arange(np.sum(b), dtype=np.double) 
out = np.zeros(LEN, dtype=np.double) 

%timeit np.sum(a.reshape(-1,2000), axis=1) 
10 loops, best of 3: 25.1 ms per loop 

%timeit accumulate.accumulate(a, b, out) 
10 loops, best of 3: 64.6 ms per loop 

avete qualche suggerimento per migliorare le prestazioni?

+0

ho implementato diversi suggerimenti, vedere versione aggiornata su github: https://gist.github.com/2784725/8e2aaebbaa68c67e7a0686e9c7927f2f5b6f419a, ancora ci vuole 63ms, quindi non significativo miglioramento –

+2

Potrebbe essere off the point, ma ho pensato di menzionare ... Numpy ha già qualcosa di simile ad esso per * all * ufuncs. 'np.add.reduceat (a, section_lengths.cumsum())'. Deve essere cambiato un po '(cumsum manca di 0 all'inizio e si ottiene la fetta finale extra) e probabilmente si può battere la velocità con cython, ma è una caratteristica/trucco molto bella. – seberg

risposta

2

Si potrebbe provare alcuni dei seguenti:

  • Oltre alla direttiva @cython.boundscheck(False) compilatore, prova anche l'aggiunta di @cython.wraparound(False)

  • Nello script setup.py, prova ad aggiungere in alcuni flag di ottimizzazione:

    ext_modules = [Extension("accumulate", ["accumulate.pyx"], extra_compile_args=["-O3",])]

  • Date un'occhiata al file HTML generato da cython -a accumulate.pyx per vedere se ci sono sezioni che mancano tipizzazione statica o dipendono pesantemente sulle chiamate C-API Python:

    http://docs.cython.org/src/quickstart/cythonize.html#determining-where-to-add-types

  • aggiungere un'istruzione return alla fine del metodo. Attualmente sta eseguendo una serie di inutili verifiche degli errori nel tuo ciclo stretto al numero i_el += 1.

  • Non sono sicuro se farà la differenza, ma tendono a rendere contatori di ciclo cdef unsigned int piuttosto che solo int

Si potrebbe anche confrontare il codice per NumPy quando section_lengths sono disuguali, dal momento che probabilmente richiederà un po 'più di un semplice sum.

+0

grazie! Ho implementato tutti i tuoi suggerimenti, ma ancora non ci sono miglioramenti significativi. Grazie per aver suggerito cython -a, non lo sapevo. Ho aggiunto un'istruzione return, che mostra alcuni controlli strani eseguiti dal codice, vedere https: //gist.github.it/2784725 # gistcomment-330807 –

+0

Sto accettando questa risposta perché fornisce suggerimenti utili, ma nessuno di essi offre miglioramenti significativi. Cambierò la risposta accettata nel caso qualcun altro trovi qualcosa di meglio. –

1

Nell'aggiornamento del ciclo di ripetizione out[i_bas] è lento, è possibile creare una variabile temporanea per eseguire l'accumulo e aggiornare out[i_bas] al termine del ciclo di loop. Il seguente codice sarà veloce come versione NumPy:

import numpy as np 
cimport numpy as np 

ctypedef np.int_t DTYPE_int_t 
ctypedef np.double_t DTYPE_double_t 

cimport cython 
@cython.boundscheck(False) 
@cython.wraparound(False) 
def accumulate(
     np.ndarray[DTYPE_double_t, ndim=1] a not None, 
     np.ndarray[DTYPE_int_t, ndim=1] section_lengths not None, 
     np.ndarray[DTYPE_double_t, ndim=1] out not None, 
     ): 
    cdef int i_el, i_bas, sec_length, lenout 
    cdef double tmp 
    lenout = out.shape[0] 
    i_el = 0 
    for i_bas in range(lenout): 
     tmp = 0 
     for sec_length in range(section_lengths[i_bas]): 
      tmp += a[i_el] 
      i_el+=1 
     out[i_bas] = tmp 
+0

grazie! seguito il tuo suggerimento ma non c'è alcun miglioramento significativo, ho aggiornato la mia versione su github –

Problemi correlati