Ho due vettori di dati e li ho inseriti in matplotlib.scatter()
. Ora mi piacerebbe tracciare un adattamento lineare a questi dati. Come lo farei? Ho provato a utilizzare scikitlearn
e np.scatter
.Come sovrapporre una linea su un grafico a dispersione in python?
risposta
import numpy as np
import matplotlib.pyplot as plt
# sample data
x = np.arange(10)
y = 5*x + 10
# fit with np.polyfit
m, b = np.polyfit(x, y, 1)
plt.plot(x, y, '.')
plt.plot(x, m*x + b, '-')
io ho un debole per scikits.statsmodels. Qui un esempio:
import statsmodels.api as sm
import numpy as np
import matplotlib.pyplot as plt
X = np.random.rand(100)
Y = X + np.random.rand(100)*0.1
results = sm.OLS(Y,sm.add_constant(X)).fit()
print results.summary()
plt.scatter(X,Y)
X_plot = np.linspace(0,1,100)
plt.plot(X_plot, X_plot*results.params[0] + results.params[1])
plt.show()
L'unica parte difficile è sm.add_constant(X)
che aggiunge un colonne di quelli a X
per ottenere un termine di intercetta.
Summary of Regression Results
=======================================
| Dependent Variable: ['y']|
| Model: OLS|
| Method: Least Squares|
| Date: Sat, 28 Sep 2013|
| Time: 09:22:59|
| # obs: 100.0|
| Df residuals: 98.0|
| Df model: 1.0|
==============================================================================
| coefficient std. error t-statistic prob. |
------------------------------------------------------------------------------
| x1 1.007 0.008466 118.9032 0.0000 |
| const 0.05165 0.005138 10.0515 0.0000 |
==============================================================================
| Models stats Residual stats |
------------------------------------------------------------------------------
| R-squared: 0.9931 Durbin-Watson: 1.484 |
| Adjusted R-squared: 0.9930 Omnibus: 12.16 |
| F-statistic: 1.414e+04 Prob(Omnibus): 0.002294 |
| Prob (F-statistic): 9.137e-108 JB: 0.6818 |
| Log likelihood: 223.8 Prob(JB): 0.7111 |
| AIC criterion: -443.7 Skew: -0.2064 |
| BIC criterion: -438.5 Kurtosis: 2.048 |
------------------------------------------------------------------------------
Un altro modo per farlo, usando axes.get_xlim()
:
import matplotlib.pyplot as plt
import numpy as np
def scatter_plot_with_correlation_line(x, y, graph_filepath):
'''
http://stackoverflow.com/a/34571821/395857
x does not have to be ordered.
'''
# Scatter plot
plt.scatter(x, y)
# Add correlation line
axes = plt.gca()
m, b = np.polyfit(x, y, 1)
X_plot = np.linspace(axes.get_xlim()[0],axes.get_xlim()[1],100)
plt.plot(X_plot, m*X_plot + b, '-')
# Save figure
plt.savefig(graph_filepath, dpi=300, format='png', bbox_inches='tight')
def main():
# Data
x = np.random.rand(100)
y = x + np.random.rand(100)*0.1
# Plot
scatter_plot_with_correlation_line(x, y, 'scatter_plot.png')
if __name__ == "__main__":
main()
#cProfile.run('main()') # if you want to do some profiling
versione ad una linea di this excellent answer per tracciare la linea della misura migliore è:
plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)))
Utilizzo di np.unique(x)
anziché x
gestisce il caso in cui x
non è ordinato o ha valori duplicati.
La chiamata a poly1d
è un'alternativa alla scrittura di m*x + b
come in this other excellent answer.
Ciao, i miei valori xey sono array convertiti da liste usando 'numpy.asarray'. Quando aggiungo questa riga di codice, ottengo diverse righe sulla mia trama di dispersione invece di una. quale potrebbe essere la ragione? – artre
@artre Grazie per averlo presentato. Questo può accadere se 'x' non è ordinato o ha valori duplicati. Ho modificato la risposta. –
plt.plot(X_plot, X_plot*results.params[0] + results.params[1])
contro
plt.plot(X_plot, X_plot*results.params[1] + results.params[0])
- 1. Come sovrapporre una linea per un oggetto lm su un grafico a dispersione ggplot2
- 2. R: Come sovrapporre i grafici a torta su "punti" in un grafico a dispersione in R
- 3. linea + grafico a dispersione in nvd3
- 4. Come disegnare una griglia su un grafico in Python?
- 5. Come creare un grafico a dispersione 3D in Python?
- 6. Come animare un grafico a dispersione?
- 7. aggiungi una linea di regressione logaritmica a un grafico a dispersione (confronto con Excel)
- 8. Python Matplotlib sovrappone i grafici a dispersione
- 9. Come si può sovrapporre un'ellisse di dati su un diagramma a dispersione ggplot2?
- 10. Come assegnare una scala di colori a una variabile in un grafico a dispersione 3D?
- 11. Grafico a dispersione matplotlib con errore sconosciuto
- 12. Come sovrapporre un Jbutton a una JprogressBar
- 13. Grafico a dispersione con dati scalari
- 14. Come visualizzare una relazione non lineare in un grafico a dispersione
- 15. Grafico del tipo di dispersione in Dygraphs?
- 16. Come combinare il grafico a dispersione con il grafico a linee per mostrare la linea di regressione? JavaFX
- 17. Collegamento di due punti in un grafico a dispersione 3D in Python e matplotlib
- 18. Combinazione di un grafico a dispersione con trama di superficie
- 19. valore alfa di controllo su grafico a dispersione 3D utilizzando Python e matplotlib
- 20. Annota grafico a dispersione da un dataframe panda
- 21. Sovrapporre una tela su un div
- 22. Riga di regressione lineare nel grafico a dispersione MATLAB
- 23. Come etichettare i punti su un grafico a dispersione con R?
- 24. Posso disegnare una linea di regressione e mostrare i parametri utilizzando il grafico a dispersione con un dataframe panda?
- 25. Excel: come posso creare un grafico a dispersione con i colori di una terza colonna?
- 26. Come assegnare un colore a ogni classe nel grafico a dispersione in R?
- 27. ggplot2: C'è un modo per sovrapporre un singolo grafico a tutti gli aspetti in un ggplot
- 28. Come applicare contemporaneamente colore/forma/dimensione in un grafico a dispersione usando trama?
- 29. tracciamento di due vettori di dati su un grafico a dispersione GGPLOT2 utilizzando R
- 30. Come posso sovrapporre un controllo su una finestra?
mia figura appare diversa; la linea è nel posto sbagliato; sopra i punti – David
@David: gli array params sono nel verso sbagliato. Prova: plt.plot (X_plot, X_plot * results.params [1] + results.params [0]). O, ancora meglio: plt.plot (X, results.fittedvalues) come la prima formula assume y è lineare è x che mentre vero qui, non è sempre il caso. – Ian