Storch - GPU Accelerated Deep Learning for Scala 3

Sören Brunk

USU AI Services

AI is Mainstream

AI is software - just programmed differently

AI needs software

Figure source: D. Sculley et al.: Hidden Technical Debt in Machine Learning Systems
[Figure: Apps A, B, and C, each consisting of business logic, UI, persistence, and auth]

There’s a gap though

[Figure: the diagram from the previous slide (D. Sculley et al.: Hidden Technical Debt in Machine Learning Systems; Apps A, B, and C)]

Storch - a Scala API for PyTorch

Storch

import torch.*
import torch.nn.functional as F


class LeNet[D <: BFloat16 | Float32: Default]
extends nn.Module:
  val conv1 = register(nn.Conv2d(1, 6, 5))
  val conv2 = register(nn.Conv2d(6, 16, 5))
  val fc1 = register(nn.Linear(16 * 4 * 4, 120))
  val fc2 = register(nn.Linear(120, 84))
  val fc3 = register(nn.Linear(84, 10))


  def apply(i: Tensor[D]): Tensor[D] =
    var x = F.maxPool2d(F.relu(conv1(i)), (2, 2))
    x = F.maxPool2d(F.relu(conv2(x)), 2)
    x = x.view(-1, 16 * 4 * 4)
    x = F.relu(fc1(x))
    x = F.relu(fc2(x))
    x = fc3(x)
    x

PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Let’s train a model!

[Figure: example images labeled Cat, Dog, Cat, and ?]

Demo

./ImageClassifier train \
  --dataset-dir cats-vs-dogs-sample \
  --base-model ResNet34

Storch Features

Storch aims to offer the same core features as PyTorch:

  • GPU accelerated tensor operations
  • Automatic differentiation
  • A neural network API for building and training machine learning models

Tensors

Tensors are homogeneous multidimensional rectangular arrays of numbers

Creating Tensors

From numbers

torch.Tensor(1)
res3: Tensor[Int32] = tensor dtype=int32, shape=[], device=CPU 
1
torch.Tensor(0.5f)
res4: Tensor[Float32] = tensor dtype=float32, shape=[], device=CPU 
0.5000

From sequences

torch.Tensor(Seq(true, false, true))
res5: Tensor[Bool] = tensor dtype=bool, shape=[3], device=CPU 
[true, false, true]
torch.Tensor(Seq(Seq(1, 2, 3), Seq(4, 5, 6)))
res6: Tensor[Int32] = tensor dtype=int32, shape=[2, 3], device=CPU 
[[1, 2, 3],
 [4, 5, 6]]

Creating Tensors

From images (via the Scrimage library)

import torchvision.transforms.functional.toTensor
import com.sksamuel.scrimage.ImmutableImage
val image = ImmutableImage.loader().fromFile("img/cat1.jpg")
val imageTensor = toTensor(image)
image: ImmutableImage = Image [width=325, height=291, type=5]
imageTensor: Tensor[Float32] = tensor dtype=float32, shape=[3, 291, 325], device=CPU 
[[[1.0000, 1.0000, 0.9961, ..., 0.9961, 1.0000, 1.0000],
  [0.9922, 0.9843, 0.9961, ..., 1.0000, 0.9961, 1.0000],
  [0.9961, 0.9961, 0.8941, ..., 0.8471, 0.9647, 0.9922],
  ...,
  [1.0000, 1.0000, 0.7216, ..., 0.5020, 0.8431, 0.9961],
  [1.0000, 0.9882, 0.6000, ..., 0.5255, 0.8431, 1.0000],
  [1.0000, 1.0000, 0.9961, ..., 0.9843, 1.0000, 0.9882]],
 [[1.0000, 1.0000, 1.0000, ..., 0.9882, 0.9961, 0.9961],
  [0.9922, 0.9922, 1.0000, ..., 1.0000, 0.9882, 0.9922],
  [1.0000, 1.0000, 0.8980, ..., 0.8431, 0.9608, 0.9843],
  ...,
  [0.9922, 1.0000, 0.7137, ..., 0.5020, 0.8431, 0.9922],
  [1.0000, 0.9765, 0.5882, ..., 0.5255, 0.8392, 1.0000],
  [1.0000, 1.0000, 0.9922, ..., 0.9843, 1.0000, 0.9922]],
 [[0.9843, 0.9843, 1.0000, ..., 0.9922, 1.0000, 1.0000],
  [0.9843, 0.9804, 1.0000, ..., 0.9922, 0.9922, 0.9961],
  [1.0000, 1.0000, 0.9059, ..., 0.8353, 0.9529, 0.9882],
  ...,
  [0.9647, 0.9647, 0.6627, ..., 0.4549, 0.8118, 0.9725],
  [0.9725, 0.9412, 0.5529, ..., 0.4941, 0.8196, 0.9843],
  [0.9922, 0.9843, 0.9765, ..., 0.9529, 0.9804, 0.9725]]]
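
The printed values are consistent with the usual `toTensor` convention: 8-bit channel values are divided by 255, and the layout changes from height x width x channel to channel x height x width (hence shape [3, 291, 325] for a 325 x 291 image). A minimal plain-Scala sketch of that conversion, without Storch or Scrimage; `Pixel` and `toChw` are hypothetical names used only for illustration:

```scala
// Hypothetical stand-in for one RGB pixel with 8-bit channels (0 to 255).
final case class Pixel(r: Int, g: Int, b: Int)

/** Converts rows of pixels (height x width) into channel-first floats
  * (3 x height x width), scaling each channel value into [0, 1].
  */
def toChw(image: Vector[Vector[Pixel]]): Vector[Vector[Vector[Float]]] =
  val channels = Vector((p: Pixel) => p.r, (p: Pixel) => p.g, (p: Pixel) => p.b)
  channels.map(channel => image.map(row => row.map(p => channel(p) / 255f)))

// A tiny 1 x 2 image: two almost-white pixels.
val tiny = Vector(Vector(Pixel(255, 255, 251), Pixel(254, 253, 255)))
val chw = toChw(tiny) // shape 3 x 1 x 2; 254 / 255 is about 0.9961
```

This explains why values such as 0.9961 and 0.9922 appear so often in the output above: they are 254/255 and 253/255.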

Tensor Operations

val m = torch.rand(size=Seq(3,3))
m: Tensor[Float32] = tensor dtype=float32, shape=[3, 3], device=CPU 
[[0.4963, 0.7682, 0.0885],
 [0.1320, 0.3074, 0.6341],
 [0.4901, 0.8964, 0.4556]]
val t = torch.arange(0, 3, dtype=float32)
t: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CPU 
[0.0000, 1.0000, 2.0000]
m matmul t
res12: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CPU 
[0.9452, 1.5756, 1.8077]
t + t
res13: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CPU 
[0.0000, 2.0000, 4.0000]
t * 2 // Seq(0, 1, 2).map(x => x * 2)
res14: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CPU 
[0.0000, 2.0000, 4.0000]
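
To make `m matmul t` concrete, here is a sketch of the same matrix-vector product in plain Scala collections, using the rounded values printed above. The helper name `matVec` is an assumption for illustration and is not part of Storch:

```scala
/** Matrix-vector product: each output element is the dot product of one row with the vector. */
def matVec(m: Seq[Seq[Double]], v: Seq[Double]): Seq[Double] =
  m.map(row => row.zip(v).map((a, b) => a * b).sum)

val m = Seq(
  Seq(0.4963, 0.7682, 0.0885),
  Seq(0.1320, 0.3074, 0.6341),
  Seq(0.4901, 0.8964, 0.4556)
)
val t = Seq(0.0, 1.0, 2.0)

// Approximately Seq(0.9452, 1.5756, 1.8076); the last digit can differ from the
// tensor output because the printed matrix entries are already rounded.
val product = matVec(m, t)
```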

Using the GPU

import torch.Device.{CPU, CUDA}
val device = if torch.cuda.isAvailable then CUDA else CPU
val m = torch.rand(size=Seq(3,3), device=device)
m: Tensor[Float32] = tensor dtype=float32, shape=[3, 3], device=CUDA 
[[0.3990, 0.5167, 0.0249],
 [0.9401, 0.9459, 0.7967],
 [0.4150, 0.8203, 0.2290]]
val t = torch.arange(0, 3, dtype=float32, device=device)
t: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CUDA 
[0.0000, 1.0000, 2.0000]
m matmul t
res15: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CUDA 
[0.5665, 2.5393, 1.2783]

https://storch.dev/installation.html#enable-gpu-support

More Ops

https://storch.dev/api/torch.html#Grouped-members

Building models

 

The simplest model

Let’s learn the identity function!

def f(x: Int) = x

class Model:
  // our model weights (randomly initialized)
  val weights = torch.rand(Seq(1))
  // our model architecture
  def apply(x: Tensor[Int32]) = x * weights

val model = new Model()
model.weights
res20: Tensor[Float32] = tensor dtype=float32, shape=[1], device=CPU 
[0.1490]

The simplest model

val x = torch.Tensor(39)
val y = x // our "label" here is just the input
val z = model(x)
z: Tensor[Float32] = tensor dtype=float32, shape=[1], device=CPU 
[5.8126]

Let’s see how far we’re off

val loss = (z - y).abs
loss: Tensor[Float32] = tensor dtype=float32, shape=[1], device=CPU 
[33.1874]

Can we change the weights to make the loss go down?

Automatic Differentiation

We compute the derivative (gradient) of our loss function

class Model:
  // our model weights (randomly initialized)
  // enable gradient tracking for our weights
  val weights = torch.rand(Seq(1), requiresGrad=true)
  // our model architecture
  def apply(x: Tensor[Int32]) = x * weights

val model = new Model()

Automatic Differentiation

val x = torch.Tensor(39)
val y = x // our "label" here is just the input
val z = model(x)
z: Tensor[Float32] = tensor dtype=float32, shape=[1], device=CPU 
[5.8126]
val loss = (z - y).abs
loss: Tensor[Float32] = tensor dtype=float32, shape=[1], device=CPU 
[33.1874]

Automatic Differentiation

Backpropagation (compute the gradients w.r.t. the weights)

loss.backward()
model.weights.grad
res62_1: Tensor[torch.Tensor[torch.Float32 | torch.Undefined]] = tensor dtype=float32, shape=[1], device=CPU 
[-39.0000]

Update weights in the opposite direction of the gradient

model.weights
res63: Tensor[Float32] = tensor dtype=float32, shape=[1], device=CPU 
[0.1490]
noGrad {
  model.weights -= (model.weights.grad * 0.001f)
}
res64: Tensor[Float32] = tensor dtype=float32, shape=[1], device=CPU 
[0.1880]
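
These numbers can be checked by hand. For the one-weight model z = x * w with loss |z - y|, the derivative with respect to w is sign(x * w - y) * x. The following plain-Scala sketch (no autograd; the function names are assumptions for illustration) reproduces the gradient and the updated weight, and cross-checks the derivative with finite differences:

```scala
/** Absolute-error loss of the one-weight model z = x * w against label y. */
def loss(w: Double, x: Double, y: Double): Double = math.abs(x * w - y)

/** Analytic gradient of the loss with respect to w: sign(x * w - y) * x. */
def gradient(w: Double, x: Double, y: Double): Double = math.signum(x * w - y) * x

/** Finite-difference approximation, used only to sanity-check the analytic gradient. */
def numericGradient(w: Double, x: Double, y: Double, eps: Double = 1e-6): Double =
  (loss(w + eps, x, y) - loss(w - eps, x, y)) / (2 * eps)

val w0 = 0.1490
val x = 39.0
val y = x                     // identity function: the label equals the input
val grad = gradient(w0, x, y) // -39.0, because the prediction is too small
val w1 = w0 - 0.001 * grad    // about 0.188, matching the updated weight above
```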

Gradient Descent

for step <- 1 to 20 do
  val x = torch.randint(low=0, high=100, size=Seq(1), dtype=int32) // input
  val y = x                   // expected output
  val z = model(x)            // predicted output
  val loss = (z - y).abs      // error/loss
  loss.backward()             // compute gradient
  noGrad {                    // disable gradient tracking locally
    // update weight based on gradient
    model.weights -= (model.weights.grad * 0.001f)
    model.weights.grad.zero()
  }

step    x    z  loss  old weight  gradient  new weight
   1   44    8    36      0.1880    -44.00      0.2320
   2   39    9    30      0.2320    -39.00      0.2710
   3   33    9    24      0.2710    -33.00      0.3040
   4   60   18    42      0.3040    -60.00      0.3640
   5   63   23    40      0.3640    -63.00      0.4270
   6   79   34    45      0.4270    -79.00      0.5060
   7   27   14    13      0.5060    -27.00      0.5330
   8    3    2     1      0.5330     -3.00      0.5360
   9   97   52    45      0.5360    -97.00      0.6330
  10   83   53    30      0.6330    -83.00      0.7160
  11    1    1     0      0.7160     -1.00      0.7170
  12   66   47    19      0.7170    -66.00      0.7830
  13   56   44    12      0.7830    -56.00      0.8390
  14   99   83    16      0.8390    -99.00      0.9380
  15   78   73     5      0.9380    -78.00      1.0160
  16   76   77     1      1.0160     76.00      0.9400
  17   56   53     3      0.9400    -56.00      0.9960
  18   68   68     0      0.9960    -68.00      1.0640
  19   94  100     6      1.0640     94.00      0.9700
  20   33   32     1      0.9700    -33.00      1.0030
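
The weight trajectory can be replayed without any tensor library. This sketch assumes the same model (z = x * w), the same loss (|z - y| with y = x), and the same learning rate of 0.001, and feeds in the 20 inputs from the x column above; `step` is a hypothetical helper name:

```scala
/** One gradient-descent step for the model z = x * w with loss |z - y| and label y = x. */
def step(w: Double, x: Double, lr: Double = 0.001): Double =
  val gradient = math.signum(x * w - x) * x // derivative of |x * w - x| with respect to w
  w - lr * gradient

// Inputs copied from the x column of the table above.
val inputs = Seq(44, 39, 33, 60, 63, 79, 27, 3, 97, 83, 1, 66, 56, 99, 78, 76, 56, 68, 94, 33)

// scanLeft keeps every intermediate weight, starting from the old weight of step 1.
val weights = inputs.scanLeft(0.1880)((w, x) => step(w, x.toDouble))
val finalWeight = weights.last // about 1.003, the new weight after step 20
```

Because the gradient of an absolute-error loss has constant magnitude x, the weight overshoots 1.0 in steps 15 and 18 and is pulled back in steps 16 and 19, which is the sign flip visible in the gradient column.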

That, in essence, is how neural networks are trained!

The neural network API

class Model:
  val weights = torch.rand(Seq(1), requiresGrad=true)
  def apply(x: Tensor[Int32]) = x * weights

val model = new Model()

Let’s use the torch.nn API instead

import torch.nn

val model = nn.Linear(inFeatures=1, outFeatures=1) // Linear is a nn.Module

Scale up

val model = nn.Linear(inFeatures=12, outFeatures=1)

Scale up

val model = nn.Linear(inFeatures=12, outFeatures=8)

Scale up

val model = nn.Sequential( // Sequential is another nn.Module, a container
    nn.Linear(12, 8),
    nn.ReLU(),             // ReLU is a non-linear activation function
    nn.Linear(8, 2)
)

Scale up

val model = nn.Sequential( // Sequential is another nn.Module, a container
    nn.Linear(12, 8),
    nn.ReLU(),             // ReLU is a non-linear activation function
    nn.Linear(8, 8),
    nn.ReLU(),
    nn.Linear(8, 2)
)
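
What such a container computes can be sketched in a few lines of plain Scala: each layer is a function from vector to vector, and the container applies them in order. The `Linear`, `relu`, and `sequential` definitions below are simplified stand-ins written for this illustration, not Storch's implementations, and the weights are hand-picked rather than learned:

```scala
type Vec = Vector[Double]

/** Fully connected layer: output(j) = bias(j) + sum over i of weight(j)(i) * input(i). */
final class Linear(weight: Vector[Vec], bias: Vec) extends (Vec => Vec):
  def apply(x: Vec): Vec =
    weight.zip(bias).map((row, b) => row.zip(x).map((w, xi) => w * xi).sum + b)

/** ReLU activation: negative values become zero. */
val relu: Vec => Vec = _.map(v => math.max(0.0, v))

/** Applies the layers one after another, as a sequential container does. */
def sequential(layers: (Vec => Vec)*): Vec => Vec =
  input => layers.foldLeft(input)((x, layer) => layer(x))

// Hand-picked weights for a 2 -> 2 -> 1 network (illustrative values only).
val model = sequential(
  Linear(Vector(Vector(1.0, -1.0), Vector(-1.0, 1.0)), Vector(0.0, 0.0)),
  relu,
  Linear(Vector(Vector(1.0, 1.0)), Vector(0.5))
)
val out = model(Vector(3.0, 1.0)) // Vector(2.5)
```

Without the ReLU in between, the two linear layers would collapse into a single linear map, which is why the non-linear activation matters.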

Specialize

 

import torch.nn.functional as F
import torch.nn

class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module:
  val conv1 = register(nn.Conv2d(1, 6, 5))
  val conv2 = register(nn.Conv2d(6, 16, 5))
  val fc1 = register(nn.Linear(16 * 4 * 4, 120))
  val fc2 = register(nn.Linear(120, 84))
  val fc3 = register(nn.Linear(84, 10))

  def apply(i: Tensor[D]): Tensor[D] =
    var x = F.maxPool2d(F.relu(conv1(i)), (2, 2))
    x = F.maxPool2d(F.relu(conv2(x)), 2)
    x = x.view(-1, 16 * 4 * 4)
    x = F.relu(fc1(x))
    x = F.relu(fc2(x))
    x = fc3(x)
    x
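
The `16 * 4 * 4` in `fc1` and in `x.view` follows from the layer sizes. The sketch below traces the spatial size through the network, assuming a 28 x 28 single-channel input (as in MNIST; the input size is not stated above), convolutions with stride 1 and no padding, and pooling whose stride equals its window size:

```scala
/** Output side length of a convolution with stride 1 and no padding. */
def convOut(size: Int, kernel: Int): Int = size - kernel + 1

/** Output side length of max pooling whose stride equals its window size. */
def poolOut(size: Int, window: Int): Int = size / window

// Assumption: 28 x 28 single-channel input, as in MNIST.
val afterConv1 = convOut(28, 5)              // 24
val afterPool1 = poolOut(afterConv1, 2)      // 12
val afterConv2 = convOut(afterPool1, 5)      // 8
val afterPool2 = poolOut(afterConv2, 2)      // 4
val flattened = 16 * afterPool2 * afterPool2 // 256 = 16 channels * 4 * 4
```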

Scale up more

/** ResNet architecture implementations */
object resnet:
  /** 3x3 convolution with padding */
  def conv3x3[D <: BFloat16 | Float32 | Float64: Default](
      inPlanes: Int,
      outPlanes: Int,
      stride: Int = 1,
      groups: Int = 1,
      dilation: Int = 1
  ): Conv2d[D] =
    Conv2d[D](
      inPlanes,
      outPlanes,
      kernelSize = 3,
      stride = stride,
      padding = dilation,
      groups = groups,
      bias = false,
      dilation = dilation
    )

  /** 1x1 convolution */
  def conv1x1[D <: FloatNN: Default](inPlanes: Int, outPlanes: Int, stride: Int = 1): Conv2d[D] =
    Conv2d[D](inPlanes, outPlanes, kernelSize = 1, stride = stride, bias = false)

  sealed abstract class BlockBuilder:
    val expansion: Int
    def apply[D <: BFloat16 | Float32 | Float64: Default](
        inplanes: Int,
        planes: Int,
        stride: Int = 1,
        downsample: Option[TensorModule[D]] = None,
        groups: Int = 1,
        baseWidth: Int = 64,
        dilation: Int = 1,
        normLayer: (Int => HasWeight[D] & TensorModule[D])
    ): TensorModule[D] = this match
      case BasicBlock =>
        new BasicBlock(inplanes, planes, stride, downsample, groups, baseWidth, dilation, normLayer)
      case Bottleneck =>
        new Bottleneck(inplanes, planes, stride, downsample, groups, baseWidth, dilation, normLayer)

  object BasicBlock extends BlockBuilder:
    override val expansion: Int = 1

  object Bottleneck extends BlockBuilder:
    override val expansion: Int = 4

  class BasicBlock[D <: BFloat16 | Float32 | Float64: Default](
      inplanes: Int,
      planes: Int,
      stride: Int = 1,
      downsample: Option[TensorModule[D]] = None,
      groups: Int = 1,
      baseWidth: Int = 64,
      dilation: Int = 1,
      normLayer: => (Int => TensorModule[D])
  ) extends TensorModule[D] {
    import BasicBlock.expansion

    if groups != 1 || baseWidth != 64 then
      throw new IllegalArgumentException("BasicBlock only supports groups=1 and baseWidth=64")
    if dilation > 1 then throw new NotImplementedError("Dilation > 1 not supported in BasicBlock")
    // Both conv1 and downsample layers downsample the input when stride != 1
    val conv1 = register(conv3x3[D](inplanes, planes, stride))
    val bn1 = register(normLayer(planes))
    val relu = register(ReLU(inplace = true))
    val conv2 = register(conv3x3[D](planes, planes))
    val bn2 = register(normLayer(planes))
    downsample.foreach(downsample => register(downsample)(using Name("downsample")))

    def apply(x: Tensor[D]): Tensor[D] =
      var identity = x

      var out = conv1(x)
      out = bn1(out)
      out = relu(out)

      out = conv2(out)
      out = bn2(out)

      downsample.foreach { downsample =>
        identity = downsample(x)
      }

      out += identity
      out = relu(out)

      out
    override def toString(): String = getClass().getSimpleName()
  }

  /** Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    * while original implementation places the stride at the first 1x1 convolution(self.conv1)
    * according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    * This variant is also known as ResNet V1.5 and improves accuracy according to
    * https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    */
  class Bottleneck[D <: BFloat16 | Float32 | Float64: Default](
      inplanes: Int,
      planes: Int,
      stride: Int = 1,
      downsample: Option[TensorModule[D]] = None,
      groups: Int = 1,
      baseWidth: Int = 64,
      dilation: Int = 1,
      normLayer: (Int => HasWeight[D] & TensorModule[D])
  ) extends TensorModule[D]:
    import Bottleneck.expansion

    val width = (planes * (baseWidth / 64.0)).toInt * groups
    // Both self.conv2 and self.downsample layers downsample the input when stride != 1
    val conv1 = register(conv1x1(inplanes, width))
    val bn1 = register(normLayer(width))
    val conv2 = register(conv3x3(width, width, stride, groups, dilation))
    val bn2 = register(normLayer(width))
    val conv3 = register(conv1x1(width, planes * expansion))
    val bn3 = register(normLayer(planes * expansion))
    val relu = register(ReLU(inplace = true))
    downsample.foreach(downsample => register(downsample)(using Name("downsample")))

    def apply(x: Tensor[D]): Tensor[D] =
      var identity = x

      var out = conv1(x)
      out = bn1(out)
      out = relu(out)

      out = conv2(out)
      out = bn2(out)
      out = relu(out)

      out = conv3(out)
      out = bn3(out)

      downsample.foreach { downsample =>
        identity = downsample(x)
      }

      out += identity
      out = relu(out)

      out
    override def toString(): String = getClass().getSimpleName()

  class ResNet[D <: BFloat16 | Float32 | Float64](
      block: BlockBuilder,
      layers: Seq[Int],
      numClasses: Int = 1000,
      zeroInitResidual: Boolean = false,
      groups: Int = 1,
      widthPerGroup: Int = 64,
      // each element in the tuple indicates if we should replace
      // the 2x2 stride with a dilated convolution instead
      replaceStrideWithDilation: (Boolean, Boolean, Boolean) = (false, false, false)
  )(using Default[D])(
      normLayer: (Int => HasWeight[D] & TensorModule[D]) =
        (numFeatures => BatchNorm2d[D](numFeatures))
  ) extends Module {
    var inplanes = 64
    var dilation = 1
    val baseWidth = widthPerGroup
    val conv1 = register(Conv2d(3, inplanes, kernelSize = 7, stride = 2, padding = 3, bias = false))
    val bn1 = register(normLayer(inplanes))
    val relu = register(ReLU(inplace = true))
    val maxpool = register(MaxPool2d(kernelSize = 3, stride = Some(2), padding = 1))
    val layer1 = register(makeLayer(block, 64, layers(0)))
    val layer2 = register(
      makeLayer(block, 128, layers(1), stride = 2, dilate = replaceStrideWithDilation(0))
    )
    val layer3 = register(
      makeLayer(block, 256, layers(2), stride = 2, dilate = replaceStrideWithDilation(1))
    )
    val layer4 = register(
      makeLayer(block, 512, layers(3), stride = 2, dilate = replaceStrideWithDilation(2))
    )
    val avgpool = register(AdaptiveAvgPool2d((1, 1)))
    val fc = register(Linear(512 * block.expansion, numClasses))

    for (m <- modules)
      m match
        case m: Conv2d[?] =>
          nn.init.kaimingNormal_(m.weight, mode = Mode.FanOut, nonlinearity = NonLinearity.ReLU)
        case m: BatchNorm2d[?] =>
          nn.init.constant_(m.weight, 1)
          nn.init.constant_(m.bias, 0)
        case m: nn.GroupNorm[?] =>
          nn.init.constant_(m.weight, 1)
          nn.init.constant_(m.bias, 0)
        case _ =>

    // Zero-initialize the last BN in each residual branch,
    // so that the residual branch starts with zeros, and each residual block behaves like an identity.
    // This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    if zeroInitResidual then
      for (m <- modules)
        m match
          case m: Bottleneck[?] if m.bn3.weight.dtype != DType.undefined =>
            nn.init.constant_(m.bn3.weight, 0)
          case m: Bottleneck[?] if m.bn2.weight.dtype != DType.undefined =>
            nn.init.constant_(m.bn2.weight, 0)
          case _ => // leave all other modules unchanged (avoids a MatchError)

    private def makeLayer(
        block: BlockBuilder,
        planes: Int,
        blocks: Int,
        stride: Int = 1,
        dilate: Boolean = false
    ): Sequential[D] = {
      var downsample: Option[TensorModule[D]] = None
      val previous_dilation = this.dilation
      var _stride: Int = stride
      if dilate then
        this.dilation *= stride
        _stride = 1
      if _stride != 1 || inplanes != planes * block.expansion then
        downsample = Some(
          Sequential(
            conv1x1(inplanes, planes * block.expansion, _stride),
            normLayer(planes * block.expansion)
          )
        )

      var layers = Vector[TensorModule[D]]()
      layers = layers :+ block(
        this.inplanes,
        planes,
        _stride,
        downsample,
        this.groups,
        this.baseWidth,
        previous_dilation,
        normLayer
      )

      inplanes = planes * block.expansion
      for (_ <- 1 until blocks)
        layers = layers :+
          block(
            this.inplanes,
            planes,
            groups = this.groups,
            baseWidth = this.baseWidth,
            dilation = this.dilation,
            normLayer = normLayer
          )

      Sequential(layers*)
    }

    private inline def forwardImpl(_x: Tensor[D]): Tensor[D] =
      var x = conv1(_x)
      x = bn1(x)
      x = relu(x)
      x = maxpool(x)

      x = layer1(x)
      x = layer2(x)
      x = layer3(x)
      x = layer4(x)

      x = avgpool(x)
      x = x.flatten(1)
      fc(x)

    def apply(x: Tensor[D]): Tensor[D] = forwardImpl(x)
  }

  private val weightsBaseUrl =
    "https://github.com/sbrunk/storch/releases/download/pretrained-weights/"

  case class Weights(
      url: String,
      transforms: Presets.ImageClassification
  )

  abstract class ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000): ResNet[D]
    val DEFAULT: Weights

  /** ResNet-18 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf).
    */
  object ResNet18 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000) =
      ResNet[D](BasicBlock, Seq(2, 2, 2, 2), numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet18-f37072fd.pth",
      Presets.ImageClassification(cropSize = 224)
    )
    val DEFAULT = IMAGENET1K_V1

  /** ResNet-34 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf).
    */
  object ResNet34 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000) =
      ResNet[D](BasicBlock, Seq(3, 4, 6, 3), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet34-b627a593.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val DEFAULT = IMAGENET1K_V1

  /** ResNet-50 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf)
    */
  object ResNet50 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000) =
      ResNet(Bottleneck, Seq(3, 4, 6, 3), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet50-0676ba61.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val IMAGENET1K_V2 = Weights(
      url = weightsBaseUrl + "resnet50-11ad3fa6.pth",
      transforms = Presets.ImageClassification(cropSize = 224, resizeSize = 232)
    )
    val DEFAULT = IMAGENET1K_V2

  /** ResNet-101 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf)
    */
  object ResNet101 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000) =
      ResNet(Bottleneck, Seq(3, 4, 23, 3), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet101-63fe2227.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val IMAGENET1K_V2 = Weights(
      url = weightsBaseUrl + "resnet101-cd907fc2.pth",
      transforms = Presets.ImageClassification(cropSize = 224, resizeSize = 232)
    )
    val DEFAULT = IMAGENET1K_V2

  /** ResNet-152 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf)
    */
  object ResNet152 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000) =
      ResNet(Bottleneck, Seq(3, 8, 36, 3), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet152-394f9c45.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val IMAGENET1K_V2 = Weights(
      url = weightsBaseUrl + "resnet152-f82ba261.pth",
      transforms = Presets.ImageClassification(cropSize = 224, resizeSize = 232)
    )
    val DEFAULT = IMAGENET1K_V2

  enum ResNetVariant(val factory: ResNetFactory):
    case ResNet18 extends ResNetVariant(resnet.this.ResNet18)
    case ResNet34 extends ResNetVariant(resnet.this.ResNet34)
    case ResNet50 extends ResNetVariant(resnet.this.ResNet50)
    case ResNet101 extends ResNetVariant(resnet.this.ResNet101)
    case ResNet152 extends ResNetVariant(resnet.this.ResNet152)
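
The numbers in the variant names can be recovered from the `layers` arguments passed to the factories above. By the usual convention, the depth counts the stem convolution, the convolutions inside the residual blocks, and the final linear layer, while the 1x1 downsample convolutions are not counted. A small sketch; `Block` and `depth` are hypothetical helpers for this illustration:

```scala
/** Convolutions per residual block: two 3x3 in a basic block, 1x1 + 3x3 + 1x1 in a bottleneck. */
enum Block(val convs: Int):
  case Basic extends Block(2)
  case Bottleneck extends Block(3)

/** Counted depth: stem convolution + all block convolutions + final linear layer. */
def depth(block: Block, layers: Seq[Int]): Int = 1 + layers.sum * block.convs + 1

val resnet18 = depth(Block.Basic, Seq(2, 2, 2, 2))        // 18
val resnet34 = depth(Block.Basic, Seq(3, 4, 6, 3))        // 34
val resnet50 = depth(Block.Bottleneck, Seq(3, 4, 6, 3))   // 50
val resnet101 = depth(Block.Bottleneck, Seq(3, 4, 23, 3)) // 101
val resnet152 = depth(Block.Bottleneck, Seq(3, 8, 36, 3)) // 152
```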

Training our image classifier

import scala.util.Using
import org.bytedeco.javacpp.PointerScope

def dataLoader: Iterator[(Tensor[Float32], Tensor[Int64])] = ???
val model: ResNet[Float32] = ResNet34(numClasses = 2)
val optimizer = Adam(model.parameters, lr = 1e-5)
val lossFn = nn.loss.CrossEntropyLoss()

for epoch <- 1 to numEpochs do // one epoch is a full pass over the data
  for (input, label) <- dataLoader do // loop through batches
    optimizer.zeroGrad() // clear gradients left over from the previous batch
    // ensure timely deallocation
    Using.resource(new PointerScope()) { _ =>
      val prediction = model(input)
      // compute loss (how good was the prediction)
      val loss = lossFn(prediction, label)
      loss.backward()  // compute the gradients of the loss
      optimizer.step() // update weights based on the gradients
    }
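
To give a feel for what the loss function measures, here is a plain-Scala sketch of cross-entropy over raw model outputs (logits): the negative log of the softmax probability assigned to the true class, averaged over the batch. This is an illustration with hypothetical helper names, not Storch's implementation:

```scala
/** Cross-entropy for one example: negative log of the softmax probability of the true class. */
def crossEntropy(logits: Seq[Double], label: Int): Double =
  val maxLogit = logits.max // subtracting the max keeps exp() from overflowing
  val logSumExp = maxLogit + math.log(logits.map(l => math.exp(l - maxLogit)).sum)
  logSumExp - logits(label)

/** Mean loss over a batch of examples. */
def batchLoss(logits: Seq[Seq[Double]], labels: Seq[Int]): Double =
  logits.zip(labels).map((l, y) => crossEntropy(l, y)).sum / labels.size

// Two classes, for example cat = 0 and dog = 1.
val undecided = crossEntropy(Seq(0.0, 0.0), label = 0) // ln 2, about 0.693
val confident = crossEntropy(Seq(5.0, 0.0), label = 0) // close to 0
val wrong = crossEntropy(Seq(5.0, 0.0), label = 1)     // slightly above 5
```

A confident correct prediction gives a loss near zero, while a confident wrong one is penalized heavily, so minimizing this loss pushes the model toward the right labels.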

Storch under the hood

A lot of PyTorch is written in C++

Source: https://github.com/pytorch/pytorch/

LibTorch

How can we use this from Scala?

  • LibTorch is a large and complex C++ codebase
  • Do we have to write all that JNI code by hand? 🙀
  • JavaCPP to the rescue!

 

How can we use this from Scala?

  • 🚀 The JavaCPP PyTorch bindings give us the full power of LibTorch from Java
  • 😰 Writing C++ in Java syntax is not always pleasant though
  • 😌 Storch provides an idiomatic Scala API on top

Safer Tensors with Types

PyTorch

torch.mean(torch.tensor(1))
# RuntimeError: mean(): could not infer output dtype.
# Input dtype must be either a floating point or complex dtype. Got: Long

Storch

def mean[D <: FloatNN | ComplexNN]( ... ) = ...
val t = torch.Tensor(1)
t: Tensor[Int32] = tensor dtype=int32, shape=[], device=CPU 
1
torch.mean(t)
// error:
// torch.mean(t)
//            ^
// Found:    torch.Tensor[torch.Int32]
// Required: torch.Tensor[D]
// where:    D is a type variable with constraint <: torch.FloatNN | torch.ComplexNN
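
The mechanism can be reproduced in a few lines of self-contained Scala 3. The types below are heavily simplified, hypothetical stand-ins for Storch's dtype and tensor classes; the point is only that a union-type bound on the type parameter lets the compiler reject integer tensors. `scala.compiletime.testing.typeChecks` is used to observe the rejection without breaking the build:

```scala
import scala.compiletime.testing.typeChecks

// Hypothetical, heavily simplified stand-ins for Storch's dtype and tensor types.
sealed trait DType
final class Int32 extends DType
final class Float32 extends DType
final class Float64 extends DType

final class Tensor[D <: DType](val values: Seq[Double])

/** Only floating-point tensors are accepted; the union-type bound does the checking. */
def mean[D <: Float32 | Float64](t: Tensor[D]): Double = t.values.sum / t.values.size

val floats = Tensor[Float32](Seq(1.0, 2.0, 3.0))
val ints = Tensor[Int32](Seq(1.0, 2.0, 3.0))

val floatMean = mean(floats)                   // 2.0
val intMeanCompiles = typeChecks("mean(ints)") // false: rejected at compile time
```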

Safer Tensors with Types

DType.scala

/** Type of the output tensor based on PyTorch type promotion rules
  *
  * This is a type-level implementation of the PyTorch op data type promotion rules via match types.
  */
type Promoted[T <: DType, U <: DType] <: DType = (T, U) match
  case (T, T)                                    => T
  case (U, U)                                    => U
  // ...
  case (Int64, U)                                => U
  case (T, Int64)                                => T
  case (Float16, BFloat16) | (BFloat16, Float16) => Float32
  case (Float16, U)                              => U
  case (T, Float16)                              => T
  // ...

Tensor.scala

class Tensor[D <: DType]:
  def +[D2 <: DType](other: Tensor[D2]): Tensor[Promoted[D, D2]] = ???

Safer Tensors with Types

DType.scala

/** Promoted type for tensor division */
type Div[T <: DType, U <: DType] <: DType = (T, U) match
  case (BitwiseNN, BitwiseNN) => Float32
  case _                      => Promoted[T, U]

/** Promoted type for elementwise tensor sum */
type Sum[D <: DType] <: DType = D match
  case BitwiseNN => Int64
  case D         => D

Tensor.scala

class Tensor[D <: DType]:
  /** Divides each element of this tensor by the corresponding element of `other`. */
  def /[D2 <: DType](other: Tensor[D2]): Tensor[Div[D, D2]] = div(other)

  /** Returns the sum of all elements of this tensor. */
  def sum: Tensor[Sum[D]] = Tensor(native.sum())

What’s next?

  • In progress
    • Add more model architectures (transformers!)
    • Improve op coverage
  • Ideas
    • Improve performance (e.g. mixed precision training)
    • Improve safety
      • Compile-time shapes & devices
      • Immutable Tensors
    • Higher-level training abstractions

We want YOU!

🫵

https://storch.dev

Thank you!