Jump to content


Neural Network in Java..... Need HELP...:(


  • You cannot reply to this topic
No replies to this topic

#1 noise

    New Member

  • Members
  • Pip
  • 7 posts

Posted 11 July 2004 - 06:38 PM

Dear people,
I really need help on the 2-layer neural network (inputs, hidden layer, output layer) below.
I am using the standard backpropagation learning algorithm and use momentum and variable learning rate.
The activation functions at both layers are sigmoid functions.
The problems are:
*When i used the network to train the XOR function it is working fine....but when i train the SINE function it does not work
*I cannot train the network with negative values (like in the sine training dataset). When i just take the positive values then the net can learn.
*When i train the network, the momentum does nothing, the network learn as good with zero momentum... so my question is whether i used the correct algorithm for the momentum....
Anyone can me some clue....I am so mad about this net... it takes me months for nothing....:(
Thank you in advance....


/****************************************************************************
  NNet.java
 ***************************************************************************/
 
import java.util.*;
import java.io.*;
import java.lang.*;
import java.lang.Math;
 
public class NNet extends Object {
           
 public  int numInputs ;              //number of input neurons
 public  int numHidden ;                        //number of hidden neurons
 public  int numOutputs;                        //number of output neurons
 
 private  int numPatterns;                       //number of training patterns
 private  int step4display=1;
 private  int numEpochs;
 private  double goal = 0.01;
 private  double hidLearnRate, outLearnRate;                //learning rate
 private  double momentum=0.8;                                   //momentum
 private  double n=1.05;                        //increase learning rate
 private  double p=0.7;                         //decrease learning rate
 private  double L=0.04;                       //network error threshold
 
 public  double inputs[];
 private double nHidden[];        //for calculating sensitivity at input-hidden layer
 private  double hidden[];
 private double nOut[];              //for calculating sensitivity at hidden-output layer
 public  double outPred[];
 private  double errCurPat[];     //error of the current pattern
 
 private double B1[];                 //bias at hidden layer
 private double B2[];                 //bias at output layer
 private double B1Change[];
 private double B2Change[];
 
            //training data
 public  double trainInputs[][];
 public  double trainOutputs[][];
 
 private  double W1[][]; //weights at the input-hidden layer
 private  double W2[][]; //weights at the hidden-output layer
 private  double W1Change[][];
 private  double W2Change[][];
 
 private  int curPat;                    //current pattern number, used for in training
 public  double RMSerror;        //total error
 
 private  double s1[]; //sensitivity at the input-hidden layer
 private  double s2[];     //sensitivity at the hidden-output layer
 
 
 
public NNet(int i, int h, int o, double weightRandomRange) {
            //System.out.println("enter NNet() constructor...");
            numInputs  = i; //number of inputs
            numHidden  = h;
            numOutputs = o;
           
            W1  = new double[numInputs][numHidden];
            W2  = new double[numHidden][numOutputs];
           
            W1Change  = new double[numInputs][numHidden];
            W2Change  = new double[numHidden][numOutputs];
           
            inputs  = new double[numInputs];         
            hidden  = new double[numHidden];      
            outPred  = new double[numOutputs];
            errCurPat= new double[numOutputs];
           
            B1 = new double[numHidden];
            B2 = new double[numOutputs];
           
            B1Change = new double[numHidden];
            B2Change = new double[numOutputs];
           
            nHidden  = new double[numHidden];
            nOut  = new double[numOutputs];
           
            s1 = new double[numHidden];
            s2 = new double[numOutputs];
           
            initWeights(weightRandomRange);       
           
            for (int k=0;k<numHidden;k++)
                        B1[k] = 1;
           
            for (int k=0;k<numOutputs;k++)
                        B1[k] = 1;       
}
 
            //this method assumes inputs[] is not empty
public  void feedForward(){
           
    for(int h = 0;h<numHidden;h++){
                        hidden[h] = 0.0;
                        nHidden[h] = 0.0;
        for(int i = 0;i<numInputs;i++)
                    nHidden[h] += inputs[i] * W1[i][h];         
                hidden[h] = L1_Function(nHidden[h] + B1[h]);
            }
 
    for(int o = 0;o<numOutputs;o++){
                        outPred[o] = 0.0;
                        nOut[o] = 0.0;
                for(int h = 0;h<numHidden;h++)
                    nOut[o] += hidden[h] * W2[h][o];
                outPred[o] = L2_Function(nOut[o] + B2[o]);
            }
 }
 
public  void feedForward(double ins[]){
 
            inputs=ins;
           
    for(int h = 0;h<numHidden;h++){
                        hidden[h] = 0.0;
                        nHidden[h] = 0.0;
        for(int i = 0;i<numInputs;i++)
                    nHidden[h] += inputs[i] * W1[i][h];         
                hidden[h] = L1_Function(nHidden[h] + B1[h]);
            }
 
    for(int o = 0;o<numOutputs;o++){
                        outPred[o] = 0.0;
                        nOut[o] = 0.0;
                for(int h = 0;h<numHidden;h++)
                    nOut[o] += hidden[h] * W2[h][o];
                outPred[o] = L2_Function(nOut[o] + B2[o]);
            }
 }
 
 
 public void setTrainData(double trainIns[][], double trainTargets[][]){
                       
            if (trainIns.length == trainTargets.length) {
                        numPatterns = trainIns.length;
                       
                        trainInputs = new double[numPatterns][numInputs];
                        trainOutputs = new double[numPatterns][numOutputs];
                       
                        for (int i=0; i<trainIns.length; i++) {
                                    for (int j=0; j<numInputs; j++)
                                                trainInputs[i][j] = trainIns[i][j];
                        }
                        for (int i=0; i<trainTargets.length; i++) {
                                    for (int j=0; j<numOutputs; j++)
                                                trainOutputs[i][j] = trainTargets[i][j];
                        }                                 
            }
            else {
                        System.out.print("Error... The number of training inputs is not the ");
                        System.out.println("same as that of training targets");
            }
            System.out.println("leave setTrainData");
 }
 
 public void setTrainDataFromFile(String filename,String delimiter, int numInputs, int numTargets, int numRows) {
           
            String record =  new String();
            double temp;
            int row = 0;
            StringTokenizer token;
           
            trainInputs = new double[numRows][numInputs];
            trainOutputs = new double[numRows][numOutputs];
           
            try {
                        FileReader                    fr = new FileReader(filename);
                        BufferedReader            br = new BufferedReader(fr);
                       
                        while ((record = br.readLine()) != null) {
                                    token = new StringTokenizer(record,delimiter);
                                   
                                    if (token.countTokens()!=numInputs+numTargets)
                                                break;
                                    else
                                                for(int column=0;column<numInputs+numTargets;column++){
                                                            temp = Double.valueOf(token.nextToken(delimiter)).doubleValue();
                                                            if (column<numInputs)  {
                                                                        trainInputs[row][column] = temp;
                                                                        System.out.print(temp + " ");
                                                            }                                            
                                                            else{
                                                                        trainOutputs[row][column-numInputs] = temp;
                                                                        System.out.print(temp + " ");
                                                            }
                                                }
                                                System.out.println();
                                    row+=1;          
                                    if (row>numRows)
                                                break;
                        }
                        this.numPatterns = row;
 
            }
            catch (IOException e) {
           // catch possible io errors from readLine()
           System.out.println("Uh oh, got an IOException error!");
           e.printStackTrace();
    }
           
 }
 
 public void train(double goal, double hidLearnRate, double outLearnRate, double momentum, int numEpochs, int step) {
            this.goal = goal;
            this.numEpochs = numEpochs;
            this.step4display = step;
    for(int j = 0;j <= numEpochs;j++) {  
      if (RMSerror>goal || j==0) {
            trainOneEpoch(hidLearnRate,outLearnRate, momentum);
            if (j % step4display  == 0 || j==numEpochs)        
            System.out.println("epoch = " + j + "  RMS Error = " + ((double)(Math.round(RMSerror*100000))/1000) + "%");
      }
           
    }     
 }
 
 private void trainOneEpoch(double hidLearnRate,double outLearnRate, double momentum) {
           
            RMSerror = 0.0;
            this.momentum = momentum;
            double temp=0, tempRMSerror=0, m = momentum;
            double tempW1Change[][] = new double[numInputs][numHidden];
            double tempW2Change[][] = new double[numHidden][numOutputs];
            double tempB1Change[] = new double[numHidden];
            double tempB2Change[] = new double[numOutputs];
           
 
            for (int o=0;o<numOutputs;o++)
              for(int h = 0;h<numHidden;h++){
                        tempW2Change[h][o] =0 ;
              }
                       
            for (int i = 0;i<numInputs;i++)
              for(int h = 0;h<numHidden;h++){
                        tempW1Change[i][h] = 0;
              }
                       
            for (int o=0;o<numOutputs;o++){
                        tempB2Change[o] = 0;
            }
 
            for (int h=0;h<numHidden;h++){
                        tempB1Change[h] = 0 ;
            }         
               
                        //batch processing:
            for (int j=0; j< numPatterns; j++) {
                       
                        curPat = j;
                        feedForward(trainInputs[curPat]);
                        CalcCurPatErr();
                       
                        //Be careful the order when updating weights!   
                CalcW2Change(outLearnRate);
                CalcB2Change(outLearnRate);
                CalcW1Change(hidLearnRate);       
                CalcB1Change(hidLearnRate);
 
                        for (int o=0;o<numOutputs;o++)
                          for(int h = 0;h<numHidden;h++){
                                    W2Change[h][o] = m * W2Change[h][o] - (1-m)*outLearnRate*s2[o]*hidden[h] ;     
                          }
                       
                        for (int i = 0;i<numInputs;i++)
                          for(int h = 0;h<numHidden;h++){
                                    W1Change[i][h] = m * W1Change[i][h] - (1-m)*hidLearnRate*s1[h]*inputs[i];
                          }
                                   
 
                        for (int o=0;o<numOutputs;o++){
                                    B2Change[o] = m * B2Change[o] - (1-m)*outLearnRate*s2[o] ;        
                        }
                                   
                        for (int h=0;h<numHidden;h++){
                                    B1Change[h] = m * B1Change[h] - (1-m)*outLearnRate*s1[h] ;
                        }
                                   
                            
                        for (int o=0;o<numOutputs;o++)
                                    for(int h = 0;h<numHidden;h++)
                                    tempW2Change[h][o] += W2Change[h][o];                             
                                               
                        for (int i = 0;i<numInputs;i++)
                                    for(int h = 0;h<numHidden;h++)
                                                tempW1Change[i][h] += W1Change[i][h];                               
                       
                        for (int o=0;o<numOutputs;o++)
                                    tempB2Change[o] += B2Change[o];
           
                        for (int h=0;h<numHidden;h++)
                                    tempB1Change[h] = B1Change[h];
                       
            RMSerror += CalcRMSerr();
            }
 
                        //calculate the everage weight changes
            for (int o=0;o<numOutputs;o++)
                        for(int h = 0;h<numHidden;h++)
                        tempW2Change[h][o] = tempW2Change[h][o]/(numPatterns);                          
                                   
            for (int i = 0;i<numInputs;i++)
                        for(int h = 0;h<numHidden;h++)
                                    tempW1Change[i][h]= tempW1Change[i][h]/(numPatterns);                             
 
            for (int o=0;o<numOutputs;o++)
                        tempB2Change[o] = tempB2Change[o]/(numPatterns);
 
            for (int h=0;h<numHidden;h++)
                        tempB1Change[h] = tempB1Change[h]/(numPatterns);
           
                        //calculate the everage network error
            RMSerror = RMSerror/numPatterns;
           
                        //update weights                       
            for (int o=0;o<numOutputs;o++)
                        for(int h = 0;h<numHidden;h++)
                        W2[h][o] += tempW2Change[h][o];
                                   
            for (int i = 0;i<numInputs;i++)
                        for(int h = 0;h<numHidden;h++)
                                    W1[i][h] += tempW1Change[i][h];
                       
                        //update bias
            for (int o=0;o<numOutputs;o++)
                        B2Change[o] += tempB2Change[o];   
 
            for (int h=0;h<numHidden;h++)
                        B1Change[h] += tempB1Change[h];
           
                        //calculate network error after updating weights
            for (int j=0; j< numPatterns; j++) {
                        curPat = j;
                        feedForward(trainInputs[curPat]);                    
                        CalcCurPatErr();         
                        tempRMSerror += CalcRMSerr();
            }
            tempRMSerror = tempRMSerror/numPatterns;
            double errChangeRate = (tempRMSerror-RMSerror)/RMSerror;
           
           
            if (errChangeRate > L) {
                                    //discard weight update, decrease learning rate and set momentum to zero
                        for (int o=0;o<numOutputs;o++)
                                    for(int h = 0;h<numHidden;h++)
                                    W2[h][o] -= tempW2Change[h][o];
                                               
                        for (int i = 0;i<numInputs;i++)
                                    for(int h = 0;h<numHidden;h++)
                                                W1[i][h] -= tempW1Change[i][h];
                       
                        outLearnRate *= p;
                hidLearnRate *= p;
               
                m = 0;
            }         
                                    //accept weight update, but change learning rate and restore momentum
            if (errChangeRate < 0) {
                        outLearnRate *= n;
                hidLearnRate *= n;
                m = momentum;
            }
                                    //accept weight update, learning rate and moment unchange
            if (0 < errChangeRate && errChangeRate <= L) {
                        m = momentum;
            }         
 }
 
            //calculate the error of the curren pattern
 private void CalcCurPatErr() {
            for(int o = 0;o<numOutputs;o++){
                        errCurPat[o] = (outPred[o] - trainOutputs[curPat][o]);
            }
 }
 
 private double CalcRMSerr() {
            double temp=0;
    for (int j =0; j< numOutputs; j++)
            temp += errCurPat[j] * errCurPat[j];       
            temp = Math.sqrt(temp);
            return temp;
 }
 
 
private void CalcS2(){
            for (int o=0;o<numOutputs;o++){
                        s2[o] = 2 * L2_Function_D1(nOut[o]) * errCurPat[o];
            }
 }
 
private  void CalcS1(){            
            for(int h = 0;h<numHidden;h++)
                        for(int o = 0;o<numOutputs;o++)
                                    s1[h] = L1_Function_D1(nHidden[h]) * W2[h][o] * s2[o];      
                       
 }
 
private void CalcW2Change(double outLearnRate){
            this.outLearnRate = outLearnRate;
            CalcS2();
            for (int o=0;o<numOutputs;o++){
                        for(int h = 0;h<numHidden;h++)
                        W2Change[h][o] = -outLearnRate * s2[o] * hidden[h];                        
            }
 }
 
 private  void CalcW1Change(double hidLearnRate){
            this.hidLearnRate = hidLearnRate;
            CalcS1();
            for (int i = 0;i<numInputs;i++)
                        for(int h = 0;h<numHidden;h++)
                                    W1Change[i][h] = -hidLearnRate * s1[h] * trainInputs[curPat][i];
 }
 
 private void CalcB2Change(double outLearnRate) {
            for(int o = 0;o<numOutputs;o++)
            B2Change[o] = -outLearnRate * s2[o];           
 }
 
 private void CalcB1Change(double hidLearnRate) {
            for(int h = 0;h<numHidden;h++)
            B1Change[h] = -hidLearnRate * s1[h];            
 }
 
 public void initWeights(double weightRandomRange){
            //System.out.println("enter initWeights()...");
           
            for(int i = 0;i<numInputs;i++)
            for(int h = 0;h<numHidden;h++)
                        W1[i][h] = (Math.random()*2*weightRandomRange-weightRandomRange);
                       
            for(int h = 0;h<numHidden;h++)
                        for(int o = 0;o<numOutputs;o++)
                                    W2[h][o] = (Math.random()*2*weightRandomRange-weightRandomRange);
            //System.out.println("leaving initWeights()...");
 }
 
 private  double L1_Function(double x){
       return Sigmoid(x);    
 }
 
 private double L1_Function_D1(double x) {
    double z = Sigmoid(x);
            return (z * (1 - z) + 0.1);
  }
 
 
 private  double L2_Function(double x){
    return Sigmoid(x);
    //return purelin(x);
 }
 
 private double L2_Function_D1(double x) {
    double z = Sigmoid(x);
    return (z * (1 - z) + 0.1);
    //return 1;
  }
 
 private  double Sigmoid(double x){
        return (double)(1.0f/(1.0f+Math.exp((double)(-x))));    
 }
 
 private double purelin(double x) {
            return x;
 }
 
 private double arctanh(double x) {
            return (1/2 * Math.log((1+x)/(1-x)));
 }
}

/********************************************************
TestNNet.java
********************************************************/
class TestNNet{
public static void main(String[] args)
 {
            NNet aNet;     
           
            double goal = 0.01;
            int numEpochs = 100000; //number of training cycles
            int step4display = 10;
                       
            /*
                        //XOR function:
            String fileName = "xorTrainingData.txt";
            int numInputs = 2;
            int numHidden = 4;
            int numOutputs = 1;
            int numPatterns = 4;
            double hidLearnRate = 3.5;
            double outLearnRate = 3.5;
            double momentum = 0.7;
            int weightRandomRange = 1;
            String delim = " ";
            */
                       
                        //SINE function:
            String fileName = "sineTrainingData.txt";                      
            int numInputs = 1;
            int numHidden = 10;
            int numOutputs = 1;
            int numPatterns = 21;
            double hidLearnRate = 0.051;
            double outLearnRate = 0.051;
            double momentum = 0.7;
            double weightRandomRange = 2;
            String delim = ",";
           
                       
                        //initialize the net:
            aNet = new NNet(numInputs,numHidden,numOutputs, weightRandomRange);
           
                        //set the training data for the net:
    aNet.setTrainDataFromFile(fileName, delim, numInputs, numOutputs, numPatterns);
           
                        //train the network
            aNet.train(goal,hidLearnRate,outLearnRate, momentum, numEpochs, step4display);
           
                        //display results:
    for(int i = 0;i<numPatterns;i++){
        aNet.feedForward(aNet.trainInputs[i]);
       
        for (int k=0;k<numInputs;k++)
            System.out.print(aNet.trainInputs[i][k] + " ");
       
        for (int k=0;k<numOutputs;k++)
            System.out.print(": "+aNet.trainOutputs[i][k] + " --> ");
       
        System.out.print( aNet.outPred[0]+" (neural model)");
        System.out.println();
       
    }   
   }      
}

/****************************************************************
sineTrainingData.txt
****************************************************************/
-1.0000,0.0000
-0.9000,-0.3090
-0.8000,-0.5878
-0.7000,-0.8090
-0.6000,-0.9511
-0.5000,-1.0000
-0.4000,-0.9511
-0.3000,-0.8090
-0.2000,-0.5878
-0.1000,-0.3090
0.0000,0.0000
0.1000,0.3090
0.2000,0.5878
0.3000,0.8090
0.4000,0.9511
0.5000,1.0000
0.6000,0.9511
0.7000,0.8090
0.8000,0.5878
0.9000,0.3090
1.0000,0.0000

/*****************************************************************
xorTrainingData.txt
*****************************************************************/
1.0 1.0 0.0
1.0 0.0 1.0
0.0 1.0 1.0
0.0 0.0 0.0





1 user(s) are reading this topic

0 members, 1 guests, 0 anonymous users