Entscheidungsmodell
Trainiertes Entscheidungsmodell zwischen Störmeldungen und Großwetterlagen
Methode
- python library “scikit-learn”
- non-parametric supervised learning method
- The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features
Leseprobe
Die Kritikalität von Wettermustern für Schadensmuster wird durch den Machine-Learning-Algorithmus “decision tree” realisiert. Dieser lernt aus Trainingsdaten und leitet daraus Regeln für die Anwendung auf Testdaten ab.
Die Zielgröße sind die täglichen Störmeldungen in den 34 PDs der Bahn von 2015 bis 2019. Die unabhängigen Variablen sind die Monate und die Großwetterlagen. Diese gelten für alle PDs gleichermaßen. Weitere meteorologische Größen sind optional.
Die sortierte Jahressumme der Meldungen je PD dient der Orientierung. Der Anteil der Meldungen, welche durch Großwetterlagen erklärt werden können, ist kleiner, folgt aber dem tatsächlichen Verlauf.
Mittels Testdaten, d.h. Daten zu Großwetterlagen aus der Vergangenheit oder aus Klimamodellen, lassen sich Aussagen über Entwicklungen ableiten.
Ergebnis
Code
Python
import sys
import os
import matplotlib
matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
from mpl_toolkits.basemap import shiftgrid
from matplotlib.path import Path
from matplotlib.patches import PathPatch
import numpy
import shapefile
import numpy as N
import scipy.ndimage
from scipy.ndimage.filters import gaussian_filter
import matplotlib.mlab as ml
from scipy.interpolate import griddata
from scipy import interpolate
import math
from scipy import signal,stats
from scipy.ndimage.filters import uniform_filter
import string
from matplotlib.collections import PatchCollection
from pylab import Polygon
import matplotlib as mpl
import pandas as pd
from sklearn import tree
from sklearn import metrics
from sklearn.externals.six import StringIO
from sklearn.metrics import accuracy_score
import pydotplus # To create our Decision Tree Graph
from IPython.display import Image # To Display a image of our graph
# Global figure look: the 'bmh' style sheet, then a small serif legend font
# applied on top of it through rcParams.
plt.style.use('bmh')
params = {
    'legend.fontsize': 8,
    'font.family': 'serif',
}
plt.rcParams.update(params)
def deg2num(lat_deg, lon_deg, zoom):
    """Convert a latitude/longitude pair in degrees to integer tile indices.

    The longitude is mapped linearly and the latitude through
    ``log(tan(lat) + 1/cos(lat))`` onto a grid of ``2 ** zoom`` tiles per
    axis (this matches the OpenStreetMap "slippy map" tile formula).

    Args:
        lat_deg: Latitude in degrees.
        lon_deg: Longitude in degrees.
        zoom: Zoom level; the grid has ``2 ** zoom`` tiles per axis.

    Returns:
        Tuple ``(xtile, ytile)`` of ints, obtained by truncating with ``int``.
    """
    tiles_per_axis = 2.0 ** zoom
    phi = math.radians(lat_deg)
    mercator_y = math.log(math.tan(phi) + (1 / math.cos(phi)))
    column = int((lon_deg + 180.0) / 360.0 * tiles_per_axis)
    row = int((1.0 - mercator_y / math.pi) / 2.0 * tiles_per_axis)
    return (column, row)
def moving_summe(a, n=3):
    """Return the sums of every run of ``n`` consecutive values in ``a``.

    Args:
        a: One-dimensional sequence of numbers.
        n: Window length (default 3).

    Returns:
        Float array holding one sum per complete window; the first ``n - 1``
        positions, which have no complete window, are dropped.
    """
    running = N.cumsum(a, dtype=float)
    window = running.copy()
    # Running total minus the running total n steps earlier = sum of the last n.
    window[n:] = running[n:] - running[:-n]
    return window[n - 1:]
def moving_average(a, n=3):
    """Return the means of every run of ``n`` consecutive values in ``a``.

    Args:
        a: One-dimensional sequence of numbers.
        n: Window length (default 3).

    Returns:
        Float array holding one mean per complete window; the first ``n - 1``
        positions, which have no complete window, are dropped.
    """
    running = N.cumsum(a, dtype=float)
    window = running.copy()
    # Difference of running totals n steps apart gives the window sum.
    window[n:] = running[n:] - running[:-n]
    complete = window[n - 1:]
    return complete / n
# --- Observed GWL catalogue ---------------------------------------------
# Columns 0, 1, 2 are read as unsigned integers into dd, mm, jj and column 3
# as text into gw.  jj is compared against years below; dd/mm are presumably
# day and month -- confirm against the file layout.
file = '../dat/gwlneudatum.dat'
dd, mm, jj = [
    N.genfromtxt(file, usecols=(col), skip_header=0, dtype="I")
    for col in (0, 1, 2)
]
gw = N.genfromtxt(file, usecols=(3), skip_header=0, dtype=str)

# Catalogue rows that fall inside 2015..2019 and their GWL codes.
ind = N.where((jj >= 2015) & (jj <= 2019))[0]
nd = len(ind)
gwls = gw[ind]

# Climate models evaluated below.  Earlier, overwritten versions of this list
# also contained HadGEM2-ES, EC-EARTH and CanESM2.
gcms = ['MPI-ESM-LR', 'MIROC5', 'CNRM-CM5', 'IPSL-CM5A-MR', 'NorESM1-M']
ng = len(gcms)

# First year of each evaluation period.
peri = [1961, 1981, 2000, 2021, 2041, 2061]
np = len(peri)

# Result cube indexed [model, PD, period, filter]; NaN marks "not computed".
krit = N.full((ng, 34, np, 2), N.nan, dtype=float)

filter = 'Filter1'

# Mercator map; reading the shapefile attaches m.pds / m.pds_info.
m = Basemap(
    projection='merc',
    llcrnrlat=47, urcrnrlat=55.2,
    llcrnrlon=5, urcrnrlon=16,
    resolution='l',
)
m.readshapefile('../shp/OSS_M1_PROD_DURCH_FL_EXT_geo', 'pds', zorder=10)

# Accumulators that the per-PD loop appends to.
obs1, sim1 = [], []
obs2, sim2 = [], []
nam = []
xs, ys, zs = [], [], []
k = -1
# Month numbers 1..12 mapped to the labels used as one-hot feature names.
_MONTH_LABELS = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                 'Jul', 'Aug', 'Sep', 'Okt', 'Nov', 'Dec')


def _month_labels(months):
    """Return the month label for every month number in ``months``.

    Values outside 1..12 are left as the empty string.
    """
    months = N.asarray(months)
    # An explicit width is required: dtype=str allocates one-character
    # strings, which silently truncates 'Jan', 'Jun' and 'Jul' all to 'J'.
    labels = N.zeros(months.shape, dtype='U3')
    for number, label in enumerate(_MONTH_LABELS, start=1):
        labels[months == number] = label
    return labels


def _threshold_labels(values, limit, label):
    """Label values >= ``limit`` with ``label`` and values below 'normal'.

    Entries that satisfy neither comparison (e.g. NaN) stay empty strings.
    """
    values = N.asarray(values)
    width = max(len(label), len('normal'))
    labels = N.zeros(values.shape, dtype='U%d' % width)
    labels[values >= limit] = label
    labels[values < limit] = 'normal'
    return labels


def _load_gwl_series(gcm):
    """Read one GWL source and return ``(years, months, gwl_codes)`` arrays."""
    if gcm != 'HESS-BREZ':
        # Model files: ';'-separated, column 0 is a date string whose
        # characters 0-3 are the year and 5-6 the month, column 3 the GWL.
        path = '../dat/' + gcm + '.dat'
        tt = N.genfromtxt(path, usecols=(0), skip_header=0, delimiter=';', dtype=str)
        gg = N.genfromtxt(path, usecols=(3), skip_header=0, delimiter=';', dtype=str)
        years = N.array([int(t[0:4]) for t in tt])
        months = N.array([int(t[5:7]) for t in tt])
    else:
        # Observed catalogue: whitespace-separated month / year / GWL columns.
        path = '../dat/gwlneudatum.dat'
        months = N.genfromtxt(path, usecols=(1), skip_header=0, dtype='i')
        years = N.genfromtxt(path, usecols=(2), skip_header=0, dtype='i')
        gg = N.genfromtxt(path, usecols=(3), skip_header=0, dtype=str)
    return years, months, gg


# The GWL series do not depend on the PD, so read every file once up front
# instead of once per PD.
gwl_series = dict((gcm, _load_gwl_series(gcm)) for gcm in gcms)

# One pass per PD polygon: train on 2015-2019 observations, then apply the
# trees to the GWL sequence of every model and period.
for k, (info, _shape) in enumerate(zip(m.pds_info, m.pds)):
    p = info['NAME']
    # The attribute is bytes on Python 2 but already text on Python 3.
    nam.append(p.decode('utf-8') if isinstance(p, bytes) else p)

    # Daily report counts of this PD for both filters.
    data = N.genfromtxt('../csv/pds-uas/' + p + '.csv', names=True, comments='#',
                        delimiter=';', dtype=None, encoding='utf-8')
    sel = N.where((data['ja'] >= 2015) & (data['ja'] <= 2019))[0]
    ua1 = data['Filter1'][sel]
    ua2 = data['Filter2'][sel]

    # Daily weather observations of this PD.
    data = N.genfromtxt('../csv/obs-dwd/' + p + '.csv', names=True, comments='#',
                        delimiter=';', dtype=None, encoding='utf-8')
    sel = N.where((data['ja'] >= 2015) & (data['ja'] <= 2019))[0]
    mona = data['mo'][sel]
    tmax = data['tmax'][sel]
    nied = data['nied'][sel]
    wmax = data['wmax'][sel]
    xs.extend(mona)

    # The weather flags are stored for reference; only 'gwls' and 'mona'
    # are used as model features below.
    df = pd.DataFrame()
    df['mona'] = _month_labels(mona)
    df['wmax'] = _threshold_labels(wmax, 25, 'wmax25')
    df['tmax'] = _threshold_labels(tmax, 30, 'tmax30')
    df['nied'] = _threshold_labels(nied, 20, 'nied20')
    df['gwls'] = N.array(gwls, str)
    df['bahn1'] = ua1
    df['bahn2'] = ua2

    obs1.append(N.sum(ua1))
    obs2.append(N.sum(ua2))
    ys.extend(ua1)

    # Fit one tree per filter once; both are reused for every model/period
    # below instead of being refitted on identical data each time.
    one_hot_data = pd.get_dummies(df[['gwls', 'mona']])
    clf1 = tree.DecisionTreeClassifier(criterion="gini").fit(one_hot_data, df['bahn1'])
    clf2 = tree.DecisionTreeClassifier(criterion="gini").fit(one_hot_data, df['bahn2'])

    pred = clf1.predict(one_hot_data)
    sim1.append(N.sum(pred))
    zs.extend(pred)
    sim2.append(N.sum(clf2.predict(one_hot_data)))

    for ig, gcm in enumerate(gcms):
        jj, mm, gg = gwl_series[gcm]
        # These two sources are only evaluated for the first three periods.
        if gcm in ('HESS-BREZ', 'NCEP-NCAR'):
            n_per = 3
        else:
            n_per = len(peri)

        for ip in range(n_per):
            ja = peri[ip]
            je = ja + 19
            sel = N.where((jj >= ja) & (jj <= je))[0]
            if len(sel) == 0:
                # No data for this period: keep the NaN that krit was
                # initialised with rather than failing in predict().
                continue

            df0 = pd.DataFrame()
            df0['mona'] = _month_labels(mm[sel])
            df0['gwls'] = gg[sel]
            # Align with the training columns: categories missing here become
            # all-zero columns, categories unseen in training are dropped.
            one_hot_data1 = pd.get_dummies(df0[['gwls', 'mona']]).reindex(
                columns=one_hot_data.columns, fill_value=0)

            krit[ig, k, ip, 0] = N.sum(clf1.predict(one_hot_data1))
            krit[ig, k, ip, 1] = N.sum(clf2.predict(one_hot_data1))
# Turn the per-PD accumulators into arrays so they can be index-sorted.
obs1 = N.array(obs1)
sim1 = N.array(sim1)
obs2 = N.array(obs2)
sim2 = N.array(sim2)
nam = N.array(nam)
xs = N.array(xs)
ys = N.array(ys)
zs = N.array(zs)

fig = plt.figure(figsize=(12,10))
# One y position per PD (34 in the reference data set).
pds = N.arange(len(obs1))

# Upper row: per-PD totals sorted by observed count, one panel per filter.
# The last krit axis (0 / 1) selects the filter, matching the panel index.
upper_panels = (
    (obs1, sim1, 'Meldungen (Filter1)'),
    (obs2, sim2, 'Meldungen (Filter2)'),
)
for panel, (obs, sim, xlabel) in enumerate(upper_panels):
    plt.subplot(2, 2, panel + 1)
    order = N.argsort(obs)
    # Sums are divided by 5 (observations) and 20 (periods), presumably the
    # number of years covered, to give per-year values -- confirm.
    plt.plot(obs[order]/5., pds, 'k', label='obs')
    plt.plot(sim[order]/5., pds, 'b', label='sim')
    for i, gcm in enumerate(gcms):
        # Periods 0, 1, 2 drawn in green, orange, red; only the green line
        # carries the model name into the legend.
        plt.plot(krit[i,order,0,panel]/20., pds, 'g', label=gcm, lw=0.5)
        plt.plot(krit[i,order,1,panel]/20., pds, 'orange', lw=0.5)
        plt.plot(krit[i,order,2,panel]/20., pds, 'red', lw=0.5)
    plt.title('Entscheidungsmodell: Meldungen ~ P(Monat,GWL)', fontsize=12, weight='bold')
    plt.xlabel(xlabel, fontsize=14, weight='bold')
    plt.yticks(pds, nam[order], fontsize=10)
    plt.ylim(-1, len(pds))
    # shadow must be a real bool: the string 'true' only worked through
    # truthiness and is not a valid value for this parameter.
    plt.legend(loc=4, shadow=True)

# Lower row: change over the periods summed over all PDs, one panel per filter.
lower_panels = (
    ('Meldungen (Filter1)', (-3000, 3000)),
    ('Meldungen (Filter2)', (-200, 200)),
)
for panel, (ylabel, yrange) in enumerate(lower_panels):
    plt.subplot(2, 2, panel + 3)
    for i, gcm in enumerate(gcms):
        # Per-period totals minus the first period's total.
        # NOTE(review): the divisor is 30 here but 20 in the upper panels --
        # confirm which one is intended.
        tmp = N.sum(krit[i,:,:,panel],0)/30. - N.sum(krit[i,:,0,panel],0)/30.
        # Position of the text label: third period for the two sources that
        # are limited to three periods, otherwise the last period.
        x = 2 if gcm in ('HESS-BREZ', 'NCEP-NCAR') else len(peri) - 1
        # Re-reference to period index 2, then accumulate over the periods.
        tmp = N.cumsum(tmp - tmp[2])
        plt.plot(peri, tmp, 'r-o', label=gcm, lw=0.5)
        plt.text(peri[x]+5, tmp[x], gcm, va='center')
    plt.ylabel(ylabel, fontsize=14, weight='bold')
    plt.xticks(peri)
    plt.xlim(1951, 2100)
    plt.ylim(*yrange)

plt.tight_layout()
plt.savefig('./img/decision.png', dpi=240, bbox_inches='tight', pad_inches=0.0)
fig = plt.figure(figsize=(12,5))

# Evaluation grid (x: 0..13, y: 0..100) and contour levels are the same for
# every panel and threshold, so they are built once.
X, Y = N.mgrid[0:13:100j,0:100:200j]
positions = N.vstack([X.ravel(),Y.ravel()])
lev = N.arange(0.5,10,0.5)

# Shared look of the round threshold markers.
marker_box = dict(boxstyle="circle", ec='k', fc='gray',alpha=0.5,lw=2)

# Panel 1 uses the values collected in ys, panel 2 those collected in zs.
panel_titles = {
    1: 'Bahn ('+filter+') ~ Monat: Beobachtung',
    2: 'Bahn ('+filter+') ~ Monat: Simulation',
}

for k, counts in ((1, ys), (2, zs)):
    plt.subplot(1,2,k)

    # One 2-D kernel density estimate per threshold over (xs value, count),
    # restricted to entries whose count exceeds the threshold.
    for sw in [0,20]:
        above = N.where(counts>sw)[0]
        xx = xs[above]
        yy = counts[above]
        values = N.vstack([xx,yy])
        print (values.shape)
        kernel = stats.gaussian_kde(values)
        Z2 = N.reshape(kernel(positions).T, X.shape)
        CF = plt.contourf(X,Y,1000*Z2,levels=lev,cmap=plt.cm.YlOrRd,extend='max',zorder=1)
        plt.contour(X,Y,1000*Z2,levels=lev,colors='k',linewidths=0.5,zorder=1)

    # Markers naming the two thresholds.
    plt.text(2,10,'N>0',size=14,weight='bold',ha='center',va='center',bbox=dict(marker_box))
    plt.text(2,30,'N>20',size=14,weight='bold',ha='center',va='center',bbox=dict(marker_box))

    plt.tick_params(direction='out')
    plt.xlim(0.5,12.5)
    plt.xticks(
        [1,2,3,4,5,6,7,8,9,10,11,12],
        ['Jan','Feb','Mar','Apr','Mai','Jun','Jul','Aug','Sep','Okt','Nov','Dez'],
    )
    plt.ylim(0,50)
    plt.grid(color='k')
    plt.ylabel('Anzahl',fontsize=16,weight='bold')
    plt.title(panel_titles[k],fontsize=14,weight='bold')

plt.savefig('./img/season.png',dpi=240,bbox_inches='tight',pad_inches=0.0)