Who is this post for?
Software developers, machine learning engineers and data scientists looking to load a PyTorch model into their Java program.
Tools this post uses:
- PyTorch
- Oracle Java 1.8
- IntelliJ IDEA
- Maven
- macOS (10.13.6)
A quick description of the tools:
Feel free to skip to the next section if you are already familiar with them.
- PyTorch:
  - A machine learning framework that competes with the likes of Keras and TensorFlow
  - Developed by Facebook AI Research
  - Can also be used in place of NumPy for array computations on GPUs
- Oracle Java 1.8:
  - A very popular programming language
  - OpenJDK is the open-source reference implementation of the Java platform, on which Oracle's JDK is based
- IntelliJ:
  - A popular IDE for Java
  - Pretty powerful when it comes to debugging and plugin integration
  - Comes in two editions: Community and Ultimate
- Maven:
  - A build automation tool
  - Makes dependency management pretty convenient
- macOS:
  - A popular operating system
Source Code
If you like, the source code for this entire post is available here and can be downloaded for free.
Getting started:
Create a new IntelliJ Maven project.
Add a new package and a main class inside it to run the sample code. This post uses the package com.playwithpytorch and a class named Main.
Update the pom.xml in your project to look like this:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>playwithpytorch</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <!-- DJL core API. In 0.8.0 it already contains the ai.djl.repository.zoo classes
             (Criteria, ModelZoo, ZooModel), so the old ai.djl:repository 0.4.1 artifact
             is not added: mixing it with api 0.8.0 would put two versions of those classes
             on the classpath. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.8.0</version>
        </dependency>
        <!-- Pretrained PyTorch models, including the object detection model used below -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>0.8.0</version>
        </dependency>
        <!-- CPU-only native PyTorch 1.6.0 binaries for macOS on Intel -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <classifier>osx-x86_64</classifier>
            <version>1.6.0</version>
            <scope>runtime</scope>
        </dependency>
    </dependencies>
</project>
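The pytorch-native-cpu dependency above is pinned to the osx-x86_64 classifier, so this pom.xml targets an Intel Mac. To build on another machine, the classifier has to match the host. The sketch below derives it from the standard os.name and os.arch system properties. The NativeClassifier class is hypothetical. The linux-x86_64 and win-x86_64 strings are assumptions that follow the same naming pattern, so check them against the artifacts published on Maven Central before relying on them. DJL has also published a pytorch-native-auto artifact meant to pick the right binaries at runtime; verify that it exists for your DJL version.

```java
import java.util.Locale;

/**
 * Hypothetical helper: maps the JVM's os.name / os.arch properties to a DJL-style
 * native classifier. Only "osx-x86_64" comes from the pom.xml above; the Linux and
 * Windows names are assumed to follow the same pattern.
 */
public class NativeClassifier {

    public static String forPlatform(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);

        String osPart;
        if (os.startsWith("mac")) {
            osPart = "osx";
        } else if (os.startsWith("linux")) {
            osPart = "linux";
        } else if (os.startsWith("windows")) {
            osPart = "win";
        } else {
            throw new IllegalArgumentException("Unsupported OS: " + osName);
        }

        // 64-bit Intel/AMD is usually reported as "x86_64" on macOS and "amd64" on Linux/Windows
        if (!arch.equals("x86_64") && !arch.equals("amd64")) {
            throw new IllegalArgumentException("Unsupported architecture: " + osArch);
        }
        return osPart + "-x86_64";
    }

    public static void main(String[] args) {
        try {
            String classifier = forPlatform(System.getProperty("os.name"), System.getProperty("os.arch"));
            // Print it in the shape it would take inside the pom.xml dependency
            System.out.println("<classifier>" + classifier + "</classifier>");
        } catch (IllegalArgumentException e) {
            System.out.println("No known classifier: " + e.getMessage());
        }
    }
}
```

On an Intel Mac like the one used in this post, this should print `<classifier>osx-x86_64</classifier>`.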
Update Main.java to look like this:
package com.playwithpytorch;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class Main {

    public static void main(String[] args)
            throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
        // Load the input image from disk
        Path inputFile = Paths.get("/Users/siddharthmudgal/test/image.jpg");
        Image img = ImageFactory.getInstance().fromFile(inputFile);

        // Ask the model zoo for an object detection model with a ResNet-50 backbone
        Criteria<Image, DetectedObjects> detectedObjectsCriteria =
                Criteria.builder()
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilter("backbone", "resnet50")
                        .build();

        // The model and the predictor hold native resources, so close both when done
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(detectedObjectsCriteria);
             Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects detectedObjects = predictor.predict(img);
            saveDetectedObjectsToDisk(detectedObjects, img);
        }
    }

    // Draws the detected bounding boxes on a copy of the image and writes it out as PNG
    public static void saveDetectedObjectsToDisk(DetectedObjects detectedObjects, Image image)
            throws IOException {
        // The image is saved as PNG (which keeps the ARGB alpha channel), so use a .png extension
        Path outFile = Paths.get("/Users/siddharthmudgal/test/output.png");
        Image outputImage = image.duplicate(Image.Type.TYPE_INT_ARGB);
        outputImage.drawBoundingBoxes(detectedObjects);
        // try-with-resources closes the stream so the file is fully written
        try (OutputStream os = Files.newOutputStream(outFile)) {
            outputImage.save(os, "png");
        }
    }
}
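drawBoundingBoxes and save hide the drawing and encoding. If you want to post-process detections yourself, for example to drop low-confidence boxes before drawing, the sketch below shows roughly what that involves. It uses only java.awt and javax.imageio, with no DJL dependency. It assumes that box coordinates are normalized to the 0-1 range relative to the image size, which is how DJL's DetectedObjects is understood to report them for this model. The Detection class, the 0.5 threshold and the sample detections are hypothetical placeholders, not DJL's API or real model output.

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.imageio.ImageIO;

public class BoxDrawingSketch {

    /** Hypothetical stand-in for one detection: label, confidence, box in 0..1 coordinates. */
    public static final class Detection {
        public final String className;
        public final double probability;
        public final double x, y, width, height;

        public Detection(String className, double probability,
                         double x, double y, double width, double height) {
            this.className = className;
            this.probability = probability;
            this.x = x;
            this.y = y;
            this.width = width;
            this.height = height;
        }
    }

    /** Keeps only detections whose confidence is at least minProbability. */
    public static List<Detection> filter(List<Detection> detections, double minProbability) {
        List<Detection> kept = new ArrayList<>();
        for (Detection d : detections) {
            if (d.probability >= minProbability) {
                kept.add(d);
            }
        }
        return kept;
    }

    /** Scales a normalized box to pixel {x, y, w, h}, clamped to stay inside the image. */
    public static int[] toPixels(Detection d, int imageWidth, int imageHeight) {
        int x = Math.max(0, Math.min((int) Math.round(d.x * imageWidth), imageWidth - 1));
        int y = Math.max(0, Math.min((int) Math.round(d.y * imageHeight), imageHeight - 1));
        int w = Math.max(1, Math.min((int) Math.round(d.width * imageWidth), imageWidth - x));
        int h = Math.max(1, Math.min((int) Math.round(d.height * imageHeight), imageHeight - y));
        return new int[] {x, y, w, h};
    }

    /** Copies the image into ARGB (like duplicate(TYPE_INT_ARGB)) and outlines each box in red. */
    public static BufferedImage draw(BufferedImage input, List<Detection> detections) {
        BufferedImage out = new BufferedImage(
                input.getWidth(), input.getHeight(), BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = out.createGraphics();
        try {
            g.drawImage(input, 0, 0, null);
            g.setColor(Color.RED);
            for (Detection d : detections) {
                int[] r = toPixels(d, out.getWidth(), out.getHeight());
                g.drawRect(r[0], r[1], r[2], r[3]); // 1-pixel outline; labels are not drawn
            }
        } finally {
            g.dispose();
        }
        return out;
    }

    public static void main(String[] args) throws IOException {
        // Blank stand-in for the decoded input photo
        BufferedImage input = new BufferedImage(640, 480, BufferedImage.TYPE_INT_RGB);
        // Made-up detections for illustration only, not real model output
        List<Detection> detections = Arrays.asList(
                new Detection("person", 0.91, 0.30, 0.10, 0.40, 0.85),
                new Detection("handbag", 0.22, 0.55, 0.60, 0.10, 0.15));
        BufferedImage annotated = draw(input, filter(detections, 0.5));

        File outFile = args.length > 0 ? new File(args[0]) : File.createTempFile("detections", ".png");
        // Format name "png" matches the .png extension and keeps the alpha channel
        if (!ImageIO.write(annotated, "png", outFile)) {
            throw new IOException("No PNG writer available");
        }
        System.out.println("Wrote " + outFile.getAbsolutePath());
    }
}
```

Pass an output path such as output.png as the first argument; otherwise it writes to a temporary file. Keep the format name and the file extension in sync. The default ImageIO JPEG writer does not reliably handle an image with an alpha channel, which is why PNG is used for the ARGB copy.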
In my case, the input image looked like this:
The output from the PyTorch object detection model looked like this:
Although I am a little bummed that the label did not include Victoria Beckham, the model does produce decent object detection outputs.
I hope this post has helped you in some way.
Cheers!