@@ -8,20 +8,13 @@ public class FeedForwardNeuralNetwork implements FunctionApproximator {
88 private final Layer hiddenLayer ;
99 private final Layer outputLayer ;
1010
11- private final double learningRate , momentum ;
12-
13- private final NNTrainingScheme trainingScheme ;
11+ private NNTrainingScheme trainingScheme ;
1412
1513 /*
16- * constructor to be used for non testing code for now assume that config
17- * contains learning rate, momentum parameter, and number of epochs. change
18- * this later to accomodate varied learning schemes like early stopping
14+ * constructor to be used for non testing code.
1915 */
2016 public FeedForwardNeuralNetwork (NNConfig config ) {
2117
22- learningRate = config .getParameterAsDouble ("learning_rate" );
23- momentum = config .getParameterAsDouble ("momentum" );
24-
2518 int numberOfInputNeurons = config
2619 .getParameterAsInteger ("number_of_inputs" );
2720 int numberOfHiddenNeurons = config
@@ -41,7 +34,6 @@ public FeedForwardNeuralNetwork(NNConfig config) {
4134 outputLayer = new Layer (numberOfOutputNeurons , numberOfHiddenNeurons ,
4235 lowerLimitForWeights , upperLimitForWeights ,
4336 new PureLinearActivationFunction ());
44- trainingScheme = new BackPropLearning (this , learningRate , momentum );
4537
4638 }
4739
@@ -52,14 +44,13 @@ public FeedForwardNeuralNetwork(NNConfig config) {
5244 */
5345 public FeedForwardNeuralNetwork (Matrix hiddenLayerWeights ,
5446 Vector hiddenLayerBias , Matrix outputLayerWeights ,
55- Vector outputLayerBias , double learningRate , double momentum ) {
56- this .learningRate = learningRate ;
57- this .momentum = momentum ;
47+ Vector outputLayerBias ) {
48+
5849 hiddenLayer = new Layer (hiddenLayerWeights , hiddenLayerBias ,
5950 new LogSigActivationFunction ());
6051 outputLayer = new Layer (outputLayerWeights , outputLayerBias ,
6152 new PureLinearActivationFunction ());
62- trainingScheme = new BackPropLearning ( this , learningRate , momentum );
53+
6354 }
6455
6556 public void processError (Vector error ) {
@@ -104,4 +95,9 @@ public Layer getOutputLayer() {
10495 return outputLayer ;
10596 }
10697
98+ public void setTrainingScheme (NNTrainingScheme trainingScheme ) {
99+ this .trainingScheme = trainingScheme ;
100+ trainingScheme .setNeuralNetwork (this );
101+ }
102+
107103}
0 commit comments