2011-11-15 17 views
8

Sto cercando di usare PyBrain per un semplice allenamento NN. Quello che non so come fare è caricare i dati di allenamento da un file. Non è spiegato nel loro sito web da nessuna parte. Non mi interessa il formato perché riesco a costruirlo ora, ma ho bisogno di farlo in un file invece di aggiungere riga per riga manualmente, perché avrò diverse centinaia di righe.Come caricare i dati di allenamento in PyBrain?

+1

Diverse centinaia di righe significa che hai un set molto piccolo e non dovresti preoccuparti delle prestazioni. Ma PyBrain non accetta solo gli array NumPy? –

+0

Non lo so, sto solo iniziando a usarlo, ma da nessuna parte si dice come usare gli array NumPy con il loro NN:/ –

risposta

21

Ecco come ho fatto:

 
ds = SupervisedDataSet(6,3) 

tf = open('mycsvfile.csv','r') 

for line in tf.readlines(): 
    data = [float(x) for x in line.strip().split(',') if x != ''] 
    indata = tuple(data[:6]) 
    outdata = tuple(data[6:]) 
    ds.addSample(indata,outdata) 

n = buildNetwork(ds.indim,8,8,ds.outdim,recurrent=True) 
t = BackpropTrainer(n,learningrate=0.01,momentum=0.5,verbose=True) 
t.trainOnDataset(ds,1000) 
t.testOnData(verbose=True) 

In questo caso la rete neurale ha 6 ingressi e 3 uscite. Il file csv ha 9 valori su ogni riga separati da una virgola. I primi 6 valori sono valori di input e gli ultimi tre sono output.

+0

che è fantastico, grazie mille. Sai come posso accedere ai valori di peso per ogni neurone? –

+1

È possibile accedere ai singoli livelli in questo modo: n ['in'] per il livello di input e n ['out'] per l'output o n ['hidden0'] per il primo livello nascosto. Non lo so, ma suppongo che tu possa accedere ai nodi del livello in qualche modo. dir (n ['in']) dovrebbe darti un suggerimento su cosa puoi fare – c0m4

+0

Non riesco a trovare come farlo. Farò una nuova domanda Grazie per l'aiuto. –

1

basta usare array panda in questo modo

import pandas as pd 

ds = SupervisedDataSet(6,3) 

dataset = pd.read_csv('mycsvfile.csv','r', delimiter=',',skiprows=1) 
ds.setfield('input' dataset.values[:,0:6]) 
ds.setfield('target', dataset.values[:,-2:-1]) 

e siete a posto.

Problemi correlati