2015-07-02 12 views
5

Ho una matrice sparsa che viene trasformata da sklearn tfidfVectorier. Credo che alcune righe siano tutte a zero righe. Voglio rimuoverli. Tuttavia, per quanto ne so, le funzioni built-in esistenti, ad es. nonzero() ed eliminate_zero(), si concentrano su voci zero, anziché su righe.matrice sparsa scipy: rimuovere le righe i cui elementi sono tutti uguali

C'è un modo semplice per rimuovere le righe tutto-zero da una matrice sparsa?

Esempio: Quello che ho ora (in realtà in formato sparse):

[ [0, 0, 0] 
    [1, 0, 2] 
    [0, 0, 1] ] 

Quello che voglio ottenere:

funzioni
[ [1, 0, 2] 
    [0, 0, 1] ] 

risposta

2

Non ci sono esistenti per questo, ma non è troppo male per scrivere il proprio:

def remove_zero_rows(M): 
    M = scipy.sparse.csr_matrix(M) 

Innanzitutto, convertire la matrice inFormato. Questo è importante perché le matrici CSR memorizzano i dati come una tripla di (data, indices, indptr), dove data contiene i valori diversi da zero, indices memorizza gli indici delle colonne e indptr contiene le informazioni sull'indice di riga. I documenti spiegano meglio:

gli indici colonna per riga i sono memorizzati in indices[indptr[i]:indptr[i+1]] ei loro valori corrispondenti sono immagazzinato in data[indptr[i]:indptr[i+1]].

Quindi, per trovare le righe senza valori diversi da zero, possiamo solo guardare i valori successivi di M.indptr. Continuando la nostra funzione dall'alto:

num_nonzeros = np.diff(M.indptr) 
    return M[num_nonzeros != 0] 

Il secondo vantaggio del formato CSR qui è che è relativamente poco costoso da affettare righe, che semplifica la creazione della matrice risultante.

1

Grazie per la risposta, @perimosocordiae

Ho appena trovato un'altra soluzione da me. Sto postando qui nel caso qualcuno possa averne bisogno in futuro.

def remove_zero_rows(X) 
    # X is a scipy sparse matrix. We want to remove all zero rows from it 
    nonzero_row_indice, _ = X.nonzero() 
    unique_nonzero_indice = numpy.unique(nonzero_row_indice) 
    return X[unique_nonzero_indice] 
5

affettare + getnnz() fa il trucco:

M = M[M.getnnz(1)>0] 

lavora direttamente sul csr_array. È possibile anche rimuovere tutti i 0 colonne senza formati cambiano:

M = M[:,M.getnnz(0)>0] 

Tuttavia, se si desidera rimuovere sia il necessario

M = M[M.getnnz(1)>0][:,M.getnnz(0)>0] #GOOD 

io non so perché, ma

M = M[M.getnnz(1)>0, M.getnnz(0)>0] #BAD 

non lo fa lavoro.

Problemi correlati