Who is this post for ?
Software developers / machine learning engineers / data scientists looking to load a PyTorch model into their Java program.
Tools this post will be using :
- PyTorch
- Oracle Java 1.8
- Intellij
- Maven
- MacOS (10.13.6)
Quick description about the tools :
Feel free to scroll to the next section if you are fairly familiar with the tools
- PyTorch :- A machine learning framework that competes with the likes of Keras and TensorFlow
- Developed by Facebook's AI research group
- Can also be used in place of numpy in GPU enabled environments.
 
- Oracle Java 1.8 :- A very popular programming language
- OpenJDK is the open-source implementation of Java that Oracle's JDK is built from
 
- Intellij :- A popular IDE used for Java
- Pretty powerful when it comes to debugging and add-on integration
- Comes in 2 versions : Community and Ultimate
 
- Maven- Build automation tool
- Makes dependency management pretty convenient
 
- MacOS- A popular operating system
 
Source Code
If you like, source code for this entire post is available here and can be downloaded for free.
Getting started :
Create a new Intellij maven project




Go ahead and add a new package and a main class inside it to run our sample code




Update the pom.xml in your project to look like this :
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <!-- Maven build descriptor for the sample project: a Java 1.8 build that pulls in the
         DJL (ai.djl) artifacts used by Main to run a PyTorch object-detection model. -->
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.example</groupId>
    <artifactId>playwithpytorch</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <!-- Compile with Java 1.8 source level and emit Java 1.8 bytecode. -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>
    <dependencies>
        <!-- DJL API artifact. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.8.0</version>
        </dependency>
        <!-- NOTE(review): this artifact is pinned to 0.4.1 while the other ai.djl artifacts are
             on 0.8.0. Confirm the mix is intentional; if the ai.djl.repository classes already
             ship inside `api` 0.8.0, this entry may be redundant and could shadow them. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>repository</artifactId>
            <version>0.4.1</version>
        </dependency>
        <!-- DJL model zoo for PyTorch; presumably the source of the model that
             ModelZoo.loadModel resolves in Main (verify). -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>0.8.0</version>
        </dependency>
        <!-- Native CPU build of PyTorch, needed only on the runtime classpath. The classifier
             pins macOS x86_64; other platforms presumably need a different classifier. -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <classifier>osx-x86_64</classifier>
            <version>1.6.0</version>
            <scope>runtime</scope>
        </dependency>
    </dependencies>
</project>
Update Main.java to look like this :
package com.playwithpytorch;
import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
 * Runs an object-detection model resolved through the DJL {@link ModelZoo} against a single
 * image and writes a copy of that image with the detected bounding boxes drawn on it.
 */
public class Main {
    /** Image the detector is run against. */
    private static final String INPUT_IMAGE = "/Users/siddharthmudgal/test/image.jpg";

    /**
     * File the annotated image is written to.
     *
     * <p>NOTE(review): the file name ends in {@code .jpeg} while the data is encoded as
     * {@link #OUTPUT_FORMAT} ("png"). The path is kept unchanged here; consider renaming it to
     * {@code output.png} so the extension matches the content.
     */
    private static final String OUTPUT_IMAGE = "/Users/siddharthmudgal/test/output.jpeg";

    /** Encoding passed to {@code Image.save}. */
    private static final String OUTPUT_FORMAT = "png";

    /**
     * Loads the input image, resolves an object-detection model with a {@code resnet50}
     * backbone, runs inference and saves the annotated result.
     *
     * @param args unused
     * @throws IOException if the input image cannot be read, the model cannot be loaded, or the
     *     output image cannot be written
     * @throws ModelNotFoundException if no model matches the criteria
     * @throws MalformedModelException if the resolved model cannot be loaded
     */
    public static void main(String[] args) throws IOException, ModelNotFoundException, MalformedModelException {
        Image img = ImageFactory.getInstance().fromFile(Paths.get(INPUT_IMAGE));

        Criteria<Image, DetectedObjects> detectedObjectsCriteria =
                Criteria.builder()
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilter("backbone", "resnet50")
                        .build();

        // Both the model and the predictor are closeable; they are released in reverse order
        // (predictor first, then model) when the block exits.
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(detectedObjectsCriteria);
             Predictor<Image, DetectedObjects> objectsPredictor = model.newPredictor()) {
            DetectedObjects detectedObjects = objectsPredictor.predict(img);
            printDetectedObjectsToDisk(detectedObjects, img);
        } catch (TranslateException e) {
            // Inference failure is reported but, as before, does not propagate out of main.
            System.err.println("Object detection failed for " + INPUT_IMAGE);
            e.printStackTrace();
        }
    }

    /**
     * Draws the detected bounding boxes on a copy of {@code image} and writes that copy to
     * {@link #OUTPUT_IMAGE}.
     *
     * @param detectedObjects detections to draw
     * @param image source image; the boxes are drawn on a duplicate of it
     * @throws IOException if the output file cannot be opened or written
     */
    public static void printDetectedObjectsToDisk(DetectedObjects detectedObjects , Image image) throws IOException {
        Path outputPath = Paths.get(OUTPUT_IMAGE);
        Image outputImage = image.duplicate(Image.Type.TYPE_INT_ARGB);
        outputImage.drawBoundingBoxes(detectedObjects);
        // Close the stream explicitly: the original handed a fresh stream to save() and never
        // closed it, leaking the file handle.
        try (OutputStream out = Files.newOutputStream(outputPath)) {
            outputImage.save(out, OUTPUT_FORMAT);
        }
    }
}
Input image in my case looked like :

Output from the pytorch object detection looked like :

Although I am a little bummed that the labels did not include Victoria Beckham, the model does produce decent object detection output.
I hope this post has helped you in some way.
Cheers!
