2015-08-03 15 views
5

Ho un grande csr_matrix e sono interessato ai primi dieci valori e ai loro indici ogni riga. Ma non ho trovato un modo decente per manipolare la matrice.Scipy.sparse.csr_matrix: come ottenere i primi dieci valori e indici?

Qui è la mia soluzione attuale e l'idea principale è quella di elaborarli riga per riga:

row = csr_matrix.getrow(row_number).toarray()[0].ravel() 
top_ten_indicies = row.argsort()[-10:] 
top_ten_values = row[row.argsort()[-10:]] 

In questo modo, i vantaggi di csr_matrix non è completamente utilizzato. È più simile a una soluzione di forza bruta.

+0

È difficile suggerire una soluzione migliore quando non ce n'è nemmeno una. La mia ipotesi è che dovrai lavorare con la versione densa, o lavorare riga per riga (probabilmente dal formato 'lil'). – hpaulj

+0

@hpaulj Aggiornato il problema, grazie – Patrick

+0

Ho trovato un'altra domanda SO che chiedeva i valori principali per l'intera matrice sparsa. Una delle risposte suggeriva di usare 'argpartion' come più veloce di' argsort'. Ma c'è ancora la domanda se si può meglio che iterare riga per riga. 'lil' e' csr' sono i 2 formati più veloci per quello. – hpaulj

risposta

6

Non vedo quali siano i vantaggi del formato csr in questo caso. Certo, tutti i valori diversi da zero vengono raccolti in un array .data, con gli indici delle colonne corrispondenti in .indices. Ma sono in blocchi di varia lunghezza. Ciò significa che non possono essere elaborati in parallelo o con i passi degli array numpy.

Una soluzione è il blocco di quei blocchi in blocchi di lunghezza comuni. Questo è ciò che fa .toarray(). Quindi puoi trovare i valori massimi con argsort(axis=1) or with argpartition`.

Un altro è quello di suddividerli in blocchi di righe e elaborare ciascuno di questi. Questo è quello che stai facendo con lo .getrow. Un altro modo per suddividerli è convertirlo nel formato lil ed elaborare le sottoliste degli array .data e .rows.

Una possibile terza opzione è utilizzare il metodo ufuncreduceat. Ciò consente di applicare i metodi ufuncreduction ai blocchi sequenziali di un array. Ci sono stabiliti ufunc come np.add che sfruttano questo. argsort non è una funzione del genere. Ma c'è un modo di costruire un ufunc da una funzione Python, e guadagnare una modesta velocità rispetto alla normale iterazione di Python. [Ho bisogno di cercare una domanda SO recente che illustri questo.]

Illustrerò un po 'di ciò con una funzione più semplice, somma su righe.

Se A2 è una matrice csr.

A2.sum(axis=1) # the fastest compile csr method 
A2.A.sum(axis=1) # same, but with a dense intermediary 
[np.sum(l.data) for l in A2] # iterate over the rows of A2 
[np.sum(A2.getrow(i).data) for i in range(A2.shape[0])] # iterate with index 
[np.sum(l) for l in A2.tolil().data] # sum the sublists of lil format 
np.add.reduceat(A2.data, A2.indptr[:-1]) # with reduceat 

A2.sum(axis=1) è implementato come una moltiplicazione matrice. Ciò non è rilevante per il problema di ordinamento, ma è comunque un modo interessante di esaminare il problema della somma. Ricorda che il formato csr è stato sviluppato per una moltiplicazione efficiente.

Per una mia matrice del campione corrente (creata per un'altra domanda SO rada)

<8x47752 sparse matrix of type '<class 'numpy.float32'>' 
    with 32 stored elements in Compressed Sparse Row format> 

alcuni momenti di confronto sono

In [694]: timeit np.add.reduceat(A2.data, A2.indptr[:-1]) 
100000 loops, best of 3: 7.41 µs per loop 

In [695]: timeit A2.sum(axis=1) 
10000 loops, best of 3: 71.6 µs per loop 

In [696]: timeit [np.sum(l) for l in A2.tolil().data] 
1000 loops, best of 3: 280 µs per loop 

Tutto il resto è 1ms o più.

Suggerisco concentrandosi sullo sviluppo della propria funzione one-fila, qualcosa di simile:

def max_n(row_data, row_indices, n): 
    i = row_data.argsort()[-n:] 
    # i = row_data.argpartition(-n)[-n:] 
    top_values = row_data[i] 
    top_indices = row_indices[i] # do the sparse indices matter? 
    return top_values, top_indices, i 

poi vedere come se si inserisce in uno di questi metodi di iterazione. tolil() sembra molto promettente.

Non ho affrontato la questione di come raccogliere questi risultati. Dovrebbero essere liste di liste, array con 10 colonne, un'altra matrice sparsa con 10 valori per riga, ecc.?


sorting each row of a large sparse & saving top K values & column index - interrogazione da diversi anni indietro, ma senza risposta.

Argmax of each row or column in scipy sparse matrix - Domanda recente che cerca argmax per righe di csr. Discuto di alcuni degli stessi problemi.

how to speed up loop in numpy? - esempio di come utilizzare np.frompyfunc per creare un ufunc. Non so se la funzione risultante ha il metodo .reduceat.

Increasing value of top k elements in sparse matrix - ottenere i primi elementi k di csr (non per riga). Caso per argpartition.


La sommatoria fila implementato con np.frompyfunc:

In [741]: def foo(a,b): 
    return a+b 
In [742]: vfoo=np.frompyfunc(foo,2,1) 
In [743]: timeit vfoo.reduceat(A2.data,A2.indptr[:-1],dtype=object).astype(float) 
10000 loops, best of 3: 26.2 µs per loop 

Ecco velocità rispettabile. Ma non riesco a pensare a un modo di scrivere una funzione binaria (richiede 2 argomenti) che implementerebbe argsort tramite riduzione. Quindi questo è probabilmente un deadend per questo problema.

+0

È fantastico !! – Patrick

0

Proprio per rispondere alla domanda iniziale (per le persone come me che hanno trovato questa domanda cerca di copy-paste), ecco una soluzione che utilizza multiprocessing sulla base di @ hpaulj suggerimento di conversione in lil_matrix, e l'iterazione di righe

from multiprocessing import Pool 

def _top_k(args): 
    """ 
    Helper function to process a single row of top_k 
    """ 
    data, row = args 
    data, row = zip(*sorted(zip(data, row), reverse=True)[:k]) 
    return data, row 

def top_k(m, k): 
    """ 
    Keep only the top k elements of each row in a csr_matrix 
    """ 
    ml = m.tolil() 
    with Pool() as p: 
     ms = p.map(_top_k, zip(ml.data, ml.rows)) 
    ml.data, ml.rows = zip(*ms) 
    return ml.tocsr() 
Problemi correlati