Nichtlineare Regression

Die lineare Regression liefert bei vielen Fragestellungen teils verblüffend gute Modelle, anhand denen sich Zusammenhänge und Prognosen erklären bzw. erstellen lassen.

Die lineare Regression hat aber auch Ihre Grenzen. Oft reicht es nicht aus, mit einer Geraden zu arbeiten, um gute sinnvolle Ergebnisse zu bekommen.

Wir bleiben bei unserem Beispiel beim gapminder Datensatz.

Es folgen Länder, bei denen die Lebenserwartung sehr schlecht mit einer Einfachen Linearen Regression erklärt werden kann (Bitte klicken Sie auf das Bild für eine höhere Auflösung):

Bei den oben genannten Ländern sind die Ergebnisse der linearen Regression nicht brauchbar für sinnvolle Prognosen. Das Modell erklärt die Zusammenhänge nicht gut genug.

Anbei folgt dasselbe Beispiel bzw. der selbe Datensatz, allerdings wurde hier eine nichtlineare Regression angewandt, mit allen zur Verfügung stehenden Datensätzen bzw. erklärenden Variablen. Die Ergebnisse finden Sie in der folgenden Graphik:

Das Modell ist keine gerade Linie mehr. Die Lebenserwartung ist bei diesen Ländern mit einer nichtlinearen Regression besser erklärt.

Im nächsten Abschnitt erstellen wir eine DecisionTreeRegressor. Dies ist ein sehr mächtiges Model welches sehr gut nichtlineare Zusammenhänge darstellen kann. Wir versuchen mit diesem Model wieder die Getriebepreise mit Hilfe deren Gewicht vorherzusagen. Dies war eine Problmestellung die wir bereits unter der Seite lineare Regression untersucht haben, mit eher schlechten Ergebnis.

Als erstes wir der TreeRegressor initialisiert:

from sklearn.tree import DecisionTreeRegressor
tree_reg = DecisionTreeRegressor()
tree_reg.fit(X,y)
y_pred = tree_reg.predict(prediction_space)

Als nächstes wird der Score wiedergegeben:

print(tree_reg.score(X, y))                                                                                                 
0.8014740447746284

Das Ergebnis ist schon sehr brauchbar!

Ein Graphische Ansicht des Ergebnis wird folgendermaßen erstelllt:

data.plot(kind="scatter", x="Gewicht", y="Preis Währungsumrechnung", alpha=0.1)
plt.plot(prediction_space,y_pred, color='black', linewidth=3)
plt.show()