2015-08-23 10 views
8

Ho recentemente iniziato la classe AI a Coursera e ho una domanda relativa alla mia implementazione dell'algoritmo di discesa del gradiente.Gradiente discesa in Java

Ecco il mio attuale implementazione (in realtà ho solo "tradotto" le espressioni matematiche in codice Java):

public class GradientDescent { 

private static final double TOLERANCE = 1E-11; 

private double theta0; 
private double theta1; 

public double getTheta0() { 
    return theta0; 
} 

public double getTheta1() { 
    return theta1; 
} 

public GradientDescent(double theta0, double theta1) { 
    this.theta0 = theta0; 
    this.theta1 = theta1; 
} 

public double getHypothesisResult(double x){ 
    return theta0 + theta1*x; 
} 

private double getResult(double[][] trainingData, boolean enableFactor){ 
    double result = 0; 
    for (int i = 0; i < trainingData.length; i++) { 
     result = (getHypothesisResult(trainingData[i][0]) - trainingData[i][1]); 
     if (enableFactor) result = result*trainingData[i][0]; 
    } 
    return result; 
} 

public void train(double learningRate, double[][] trainingData){ 
    int iteration = 0; 
    double delta0, delta1; 
    do{ 
     iteration++; 
     System.out.println("SUBS: " + (learningRate*((double) 1/trainingData.length))*getResult(trainingData, false)); 
     double temp0 = theta0 - learningRate*(((double) 1/trainingData.length)*getResult(trainingData, false)); 
     double temp1 = theta1 - learningRate*(((double) 1/trainingData.length)*getResult(trainingData, true)); 
     delta0 = theta0-temp0; delta1 = theta1-temp1; 
     theta0 = temp0; theta1 = temp1; 
    }while((Math.abs(delta0) + Math.abs(delta1)) > TOLERANCE); 
    System.out.println(iteration); 
} 

}

il codice funziona abbastanza bene, ma solo se scelgo un poco alfa, qui chiamato learningRate. Se è superiore a 0,00001, diverge.

Avete qualche suggerimento su come ottimizzare l'implementazione, o una spiegazione per la "Alpha-Issue" e una possibile soluzione per questo?

Aggiornamento:

Ecco il principale, tra cui alcuni input di esempio:

private static final double[][] TDATA = {{200, 20000},{300, 41000},{900, 141000},{800, 41000},{400, 51000},{500, 61500}}; 

public static void main(String[] args) { 
    GradientDescent gd = new GradientDescent(0,0); 
    gd.train(0.00001, TDATA); 
    System.out.println("THETA0: " + gd.getTheta0() + " - THETA1: " + gd.getTheta1()); 
    System.out.println("PREDICTION: " + gd.getHypothesisResult(300)); 
} 

L'espressione matematica di discesa del gradiente è il seguente:

enter image description here

+3

Probabilmente dovresti fornire qualche esempio di dati/input, in un metodo 'main', e forse un backlink al forum che hai" tradotto ". – Marco13

+0

Ho aggiornato la domanda e ho riscontrato anche un piccolo problema. Dopo averlo corretto, ora sono in grado di impostare la velocità di apprendimento su 0,0001. Ma penso che sia ancora piuttosto basso ma molto meglio di prima. –

+0

Quale delle classi Coursera ML/AI era? Stanford's Machine Learning? –

risposta

1

Per risolvere questo problema, è necessario normalizzare i dati con questo formulario: (Xi-mu)/s. Xi è il valore impostato per la formazione corrente, mu la media dei valori nella colonna corrente ed s il valore massimo meno il valore minimo della colonna corrente. Questa formula otterrà i dati di allenamento approssimativamente in un intervallo compreso tra -1 e 1, che consente di scegliere tassi di apprendimento più elevati e la discesa del gradiente per convergere più velocemente. Ma è poi necessario denormalizzare il risultato previsto.

0

si dovrebbe usare Java .math.BigDecimal per le tue operazioni aritematiche.
doppio ha i suoi problemi di arrotondamento durante l'esecuzione di qualsiasi arithematic.

+4

Questa è un'ipotesi, e sono vicino al downvoting di questo: la semplice sostituzione di 'double' con' BigDecimal' non risolverà necessariamente il problema, il che potrebbe (!) Essere completamente estraneo alla precisione .... – Marco13

2
private double getResult(double[][] trainingData, boolean enableFactor){ 
double result = 0; 
for (int i = 0; i < trainingData.length; i++) { 
    result = (getHypothesisResult(trainingData[i][0]) - trainingData[i][1]); 
    if (enableFactor) result = result*trainingData[i][0]; 
} 
return result; 

In questa funzione. la variabile risultato ha sovrascritto ogni iterazione, perdendo il vecchio valore. Quando si inseriscono i valori, viene calcolato solo l'ultimo elemento dell'array. Il resto di loro non importa.

+0

Hai ragione ! Grazie per il suggerimento! Si è rivelato essere una discesa con gradiente stocastico molto lento ... ancora funzionante ma lento perché tutti i campioni verranno elaborati per aggiornare i parametri con un solo campione. –

Problemi correlati