2014-12-15 18 views
5

Non vedo cosa c'è di sbagliato nel mio codice per la regressione lineare regolarizzata. Non regolarizzata Ho semplicemente questo, che sono ragionevolmente certo è corretto:Regressione lineare di Numpy con regolarizzazione

import numpy as np 

def get_model(features, labels): 
    return np.linalg.pinv(features).dot(labels) 

Ecco il mio codice per una soluzione regolarizzata, dove io non sto vedendo ciò che è sbagliato con esso:

def get_model(features, labels, lamb=0.0): 
    n_cols = features.shape[1] 
    return linalg.inv(features.transpose().dot(features) + lamb * np.identity(n_cols))\ 
      .dot(features.transpose()).dot(labels) 

Con il valore predefinito di 0.0 per lamb, la mia intenzione è che dovrebbe dare lo stesso risultato della versione (corretta) non regolamentata, ma la differenza è in realtà piuttosto grande.

Qualcuno vede qual è il problema?

+0

sto iniziando regolarizzazione, e sarebbe regolarizzare una linea di regressione lineare produce una curva? – duldi

+1

No. otterrai comunque coefficienti lineari. La regolarizzazione cambierà solo la pendenza. –

risposta

6

Il problema è:

features.transpose().dot(features) potrebbero non essere invertibile. E numpy.linalg.inv funziona solo per la matrice full-rank in base ai documenti. Tuttavia, un termine di regolarizzazione (diverso da zero) semplifica sempre l'equazione non singolare.

A proposito, avete ragione riguardo all'implementazione. Ma non è efficiente. Un modo efficace per risolvere questa equazione è il metodo dei minimi quadrati.

np.linalg.lstsq(features, labels) può fare il lavoro per np.linalg.pinv(features).dot(labels).

In linea generale, si può fare questo

def get_model(A, y, lamb=0): 
    n_col = A.shape[1] 
    return np.linalg.lstsq(A.T.dot(A) + lamb * np.identity(n_col), A.T.dot(y)) 
+0

Se si utilizza np.linalg.lstsq(), come si inserisce nel termine di regolarizzazione 'agnello'? –

+0

modifica la mia risposta. – nullas

+0

Che funziona bene! Grazie. Ho finito con 'np.linalg.lstsq (...) [0]' perché altrimenti ho ricevuto una tupla. Inoltre, sai per caso perché 'lstsq()' è più performante? –

Problemi correlati