Skip to content

Commit

Permalink
feat(#43): Convolutional Neural Networks
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Abramov committed Nov 4, 2024
1 parent 26c32c4 commit 02d031f
Show file tree
Hide file tree
Showing 11 changed files with 1,021 additions and 146 deletions.
105 changes: 105 additions & 0 deletions example/src/main/java/de/example/cnn/CnnExampleOnMNIST.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package de.example.cnn;

import de.edux.ml.api.ExecutionMode;
import de.edux.ml.mlp.core.network.NetworkBuilder;
import de.edux.ml.mlp.core.network.layers.ConvolutionalLayer;
import de.edux.ml.mlp.core.network.layers.DenseLayer;
import de.edux.ml.mlp.core.network.layers.FlattenLayer;
import de.edux.ml.mlp.core.network.layers.PoolingLayer;
import de.edux.ml.mlp.core.network.layers.ReLuLayer;
import de.edux.ml.mlp.core.network.layers.SoftmaxLayer;
import de.edux.ml.mlp.core.network.loader.Loader;
import de.edux.ml.mlp.core.network.loader.MetaData;
import de.edux.ml.mlp.core.network.loader.mnist.MnistLoader;

import java.io.File;

/**
* @author Samuel Abramov
*/
public class CnnExampleOnMNIST {
public static void main(String[] args) {
String trainImages =
"example"
+ File.separator
+ "datasets"
+ File.separator
+ "mnist"
+ File.separator
+ "train-images.idx3-ubyte";
String trainLabels =
"example"
+ File.separator
+ "datasets"
+ File.separator
+ "mnist"
+ File.separator
+ "train-labels.idx1-ubyte";
String testImages =
"example"
+ File.separator
+ "datasets"
+ File.separator
+ "mnist"
+ File.separator
+ "t10k-images.idx3-ubyte";
String testLabels =
"example"
+ File.separator
+ "datasets"
+ File.separator
+ "mnist"
+ File.separator
+ "t10k-labels.idx1-ubyte";

int batchSize = 100;
int epochs = 10;
float initialLearningRate = 0.1f;
float finalLearningRate = 0.0001f;

Loader trainLoader = new MnistLoader(trainImages, trainLabels, batchSize);
Loader testLoader = new MnistLoader(testImages, testLabels, batchSize);
MetaData trainMetaData = trainLoader.open();
int inputSize = trainMetaData.getInputSize();
int numberOfOutputClasses = trainMetaData.getNumberOfClasses();
trainLoader.close();

long startTime = System.currentTimeMillis();
new NetworkBuilder()
.addLayer(new ConvolutionalLayer(8, 3, 28, 28, 1)) // 8 Filter, 3x3, input 28x28, 1 grayscale channel
.addLayer(new ReLuLayer())
.addLayer(new PoolingLayer(8, 26, 26, 2, 2, 2)) // Pooling layer (2x2, stride 2)

.addLayer(new ConvolutionalLayer(16, 3, 13, 13, 8)) // 16 Filter, 3x3, input 13x13, 8 channels
.addLayer(new ReLuLayer())
.addLayer(new PoolingLayer(16, 11, 11, 2, 2, 2)) // Pooling layer (2x2, stride 2)

.addLayer(new FlattenLayer(16, 5, 5)) // Updated dimensions: 16 channels, 5x5 output
.addLayer(new DenseLayer(16 * 5 * 5, 128)) // Dense layer input from flattened convolution output
.addLayer(new ReLuLayer())
.addLayer(new DenseLayer(128, numberOfOutputClasses)) // Final dense layer with number of classes as output
.addLayer(new SoftmaxLayer())

// Hyperparameter configuration
.withBatchSize(batchSize)
.withLearningRates(initialLearningRate, finalLearningRate)
.withExecutionMode(ExecutionMode.SINGLE_THREAD)
.withEpochs(epochs)

// Build network
.build()
.printArchitecture()
.fit(trainLoader, testLoader)
.saveModel("cnn_mnist_trained.edux");

long endTime = System.currentTimeMillis();
System.out.println("Training took: " + (endTime - startTime) / 1000 + " seconds");

new NetworkBuilder()
.withExecutionMode(ExecutionMode.SINGLE_THREAD)
.withEpochs(2)
.withLearningRates(0.001f, 0.001f)
.loadModel("cnn_mnist_trained.edux")
.fit(trainLoader, testLoader);
}
}
4 changes: 2 additions & 2 deletions example/src/main/java/de/example/mlp/MlpExampleOnMNIST.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public static void main(String[] args) {

int batchSize = 100;
ExecutionMode singleThread = ExecutionMode.SINGLE_THREAD;
int epochs = 100;
int epochs = 5;
float initialLearningRate = 0.1f;
float finalLearningRate = 0.0001f;

Expand Down Expand Up @@ -81,7 +81,7 @@ public static void main(String[] args) {
// Loading a trained model
new NetworkBuilder()
.withExecutionMode(singleThread)
.withEpochs(5)
.withEpochs(2)
.withLearningRates(0.001f, 0.001f)
.loadModel("mnist_trained.edux")
.fit(trainLoader, testLoader);
Expand Down
2 changes: 1 addition & 1 deletion lib/src/main/java/de/edux/ml/api/ExecutionMode.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public enum ExecutionMode {
* Single-thread execution mode. In this mode, all batches are processed sequentially in a single
* thread.
*/
SINGLE_THREAD(1);
SINGLE_THREAD(1), MULTI_THREAD(6);

int threads = 1;

Expand Down
Loading

0 comments on commit 02d031f

Please sign in to comment.