torch.Tensor(1)
res3: Tensor[Int32] = tensor dtype=int32, shape=[], device=CPU 1
USU AI Services
App A |
---|
Business logic |
UI |
Persistence |
Auth |
… |
App B |
---|
Business Logic |
UI |
Persistence |
Auth |
… |
App C |
---|
Business Logic |
UI |
Persistence |
Auth |
… |
App A |
---|
Business logic |
UI |
Persistence |
Auth |
… |
App B |
---|
Business Logic |
UI |
Persistence |
Auth |
… |
App C |
---|
Business Logic |
UI |
Persistence |
Auth |
… |
Storch
import torch.*
import torch.nn.functional as F
/** LeNet convolutional network written with Storch, mirroring the PyTorch definition.
  *
  * Two 5x5 convolution + 2x2 max-pooling stages are followed by three fully connected layers;
  * the first linear layer expects 16 * 4 * 4 features per sample.
  */
class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module:
  // Feature extractor: 1 -> 6 -> 16 channels, 5x5 kernels.
  val conv1 = register(nn.Conv2d(1, 6, 5))
  val conv2 = register(nn.Conv2d(6, 16, 5))
  // Classifier head: 16 * 4 * 4 -> 120 -> 84 -> 10.
  val fc1 = register(nn.Linear(16 * 4 * 4, 120))
  val fc2 = register(nn.Linear(120, 84))
  val fc3 = register(nn.Linear(84, 10))

  /** Forward pass; returns one score per class (10 values per sample). */
  def apply(i: Tensor[D]): Tensor[D] =
    val firstStage = F.maxPool2d(F.relu(conv1(i)), (2, 2))
    val secondStage = F.maxPool2d(F.relu(conv2(firstStage)), 2)
    // Flatten every sample into a single feature vector.
    val flattened = secondStage.view(-1, 16 * 4 * 4)
    val hidden = F.relu(fc2(F.relu(fc1(flattened))))
    fc3(hidden)
PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
    """Classic LeNet convolutional network.

    Sized for single-channel 28x28 inputs of shape (N, 1, 28, 28): each of the two
    5x5 convolutions (no padding) is followed by ReLU and 2x2 max-pooling, shrinking
    the spatial size 28 -> 24 -> 12 -> 8 -> 4. That is why the first fully connected
    layer takes 16 * 4 * 4 features. ``forward`` returns raw scores of shape (N, 10).
    """

    def __init__(self):
        # Zero-argument super() is the Python 3 idiom; the explicit
        # super(LeNet, self) form is only needed for Python 2.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute class scores for a batch ``x`` of 1-channel images."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten each sample's 16 feature maps of 4x4 into one vector.
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
?
Cat
Dog
Cat
Storch aims to offer the same core features as PyTorch:
From numbers
From images (via the Scrimage library)
image: ImmutableImage = Image [width=325, height=291, type=5] imageTensor: Tensor[Float32] = tensor dtype=float32, shape=[3, 291, 325], device=CPU [[[1.0000, 1.0000, 0.9961, ..., 0.9961, 1.0000, 1.0000], [0.9922, 0.9843, 0.9961, ..., 1.0000, 0.9961, 1.0000], [0.9961, 0.9961, 0.8941, ..., 0.8471, 0.9647, 0.9922], ..., [1.0000, 1.0000, 0.7216, ..., 0.5020, 0.8431, 0.9961], [1.0000, 0.9882, 0.6000, ..., 0.5255, 0.8431, 1.0000], [1.0000, 1.0000, 0.9961, ..., 0.9843, 1.0000, 0.9882]], [[1.0000, 1.0000, 1.0000, ..., 0.9882, 0.9961, 0.9961], [0.9922, 0.9922, 1.0000, ..., 1.0000, 0.9882, 0.9922], [1.0000, 1.0000, 0.8980, ..., 0.8431, 0.9608, 0.9843], ..., [0.9922, 1.0000, 0.7137, ..., 0.5020, 0.8431, 0.9922], [1.0000, 0.9765, 0.5882, ..., 0.5255, 0.8392, 1.0000], [1.0000, 1.0000, 0.9922, ..., 0.9843, 1.0000, 0.9922]], [[0.9843, 0.9843, 1.0000, ..., 0.9922, 1.0000, 1.0000], [0.9843, 0.9804, 1.0000, ..., 0.9922, 0.9922, 0.9961], [1.0000, 1.0000, 0.9059, ..., 0.8353, 0.9529, 0.9882], ..., [0.9647, 0.9647, 0.6627, ..., 0.4549, 0.8118, 0.9725], [0.9725, 0.9412, 0.5529, ..., 0.4941, 0.8196, 0.9843], [0.9922, 0.9843, 0.9765, ..., 0.9529, 0.9804, 0.9725]]]
m: Tensor[Float32] = tensor dtype=float32, shape=[3, 3], device=CPU [[0.4963, 0.7682, 0.0885], [0.1320, 0.3074, 0.6341], [0.4901, 0.8964, 0.4556]]
m: Tensor[Float32] = tensor dtype=float32, shape=[3, 3], device=CUDA [[0.3990, 0.5167, 0.0249], [0.9401, 0.9459, 0.7967], [0.4150, 0.8203, 0.2290]]
t: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CUDA [0.0000, 1.0000, 2.0000]
res15: Tensor[Float32] = tensor dtype=float32, shape=[3], device=CUDA [0.5665, 2.5393, 1.2783]
Let’s learn the identity function!
Let’s learn the identity function!
Let’s see how far we’re off
loss: Tensor[Float32] = tensor dtype=float32, shape=[1], device=CPU [33.1874]
Can we change the weights to make the loss go down?
We compute the derivative (gradient) of our loss function
We compute the derivative (gradient) of our loss function
Backpropagation (compute the gradients w.r.t. the weights)
res62_1: Tensor[torch.Tensor[torch.Float32 | torch.Undefined]] = tensor dtype=float32, shape=[1], device=CPU [-39.0000]
// Learn the identity function with 20 steps of plain gradient descent.
for step <- 1 to 20 do
  // A random integer serves as both the input and the expected output.
  val input = torch.randint(low = 0, high = 100, size = Seq(1), dtype = int32)
  val target = input
  val prediction = model(input)
  // Absolute error between prediction and target.
  val loss = (prediction - target).abs
  loss.backward() // populates model.weights.grad
  // The weight update itself must not be tracked for gradients.
  noGrad {
    val update = model.weights.grad * 0.001f // learning rate 0.001
    model.weights -= update
    model.weights.grad.zero() // clear the gradient before the next step
  }
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
1 | 44 | 08 | 36 | 0.1880 | -44.00 | 0.2320 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
2 | 39 | 09 | 30 | 0.2320 | -39.00 | 0.2710 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
3 | 33 | 09 | 24 | 0.2710 | -33.00 | 0.3040 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
4 | 60 | 18 | 42 | 0.3040 | -60.00 | 0.3640 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
5 | 63 | 23 | 40 | 0.3640 | -63.00 | 0.4270 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
6 | 79 | 34 | 45 | 0.4270 | -79.00 | 0.5060 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
7 | 27 | 14 | 13 | 0.5060 | -27.00 | 0.5330 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
8 | 03 | 02 | 01 | 0.5330 | -3.00 | 0.5360 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
9 | 97 | 52 | 45 | 0.5360 | -97.00 | 0.6330 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
10 | 83 | 53 | 30 | 0.6330 | -83.00 | 0.7160 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
11 | 01 | 01 | 00 | 0.7160 | -1.00 | 0.7170 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
12 | 66 | 47 | 19 | 0.7170 | -66.00 | 0.7830 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
13 | 56 | 44 | 12 | 0.7830 | -56.00 | 0.8390 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
14 | 99 | 83 | 16 | 0.8390 | -99.00 | 0.9380 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
15 | 78 | 73 | 05 | 0.9380 | -78.00 | 1.0160 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
16 | 76 | 77 | 01 | 1.0160 | 76.00 | 0.9400 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
17 | 56 | 53 | 03 | 0.9400 | -56.00 | 0.9960 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
18 | 68 | 68 | 00 | 0.9960 | -68.00 | 1.0640 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
19 | 94 | 100 | 06 | 1.0640 | 94.00 | 0.9700 |
step | x | z | loss | old weight | gradient | new weight |
---|---|---|---|---|---|---|
20 | 33 | 32 | 01 | 0.9700 | -33.00 | 1.0030 |
That’s how all AIs are trained!
import torch.nn.functional as F
import torch.nn
/** LeNet convolutional network written with Storch.
  *
  * Two 5x5 convolution + 2x2 max-pooling stages are followed by three fully connected layers;
  * the first linear layer expects 16 * 4 * 4 features per sample.
  */
class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module:
  // Feature extractor: 1 -> 6 -> 16 channels, 5x5 kernels.
  val conv1 = register(nn.Conv2d(1, 6, 5))
  val conv2 = register(nn.Conv2d(6, 16, 5))
  // Classifier head: 16 * 4 * 4 -> 120 -> 84 -> 10.
  val fc1 = register(nn.Linear(16 * 4 * 4, 120))
  val fc2 = register(nn.Linear(120, 84))
  val fc3 = register(nn.Linear(84, 10))

  /** Forward pass; returns one score per class (10 values per sample). */
  def apply(i: Tensor[D]): Tensor[D] =
    val firstStage = F.maxPool2d(F.relu(conv1(i)), (2, 2))
    val secondStage = F.maxPool2d(F.relu(conv2(firstStage)), 2)
    // Flatten every sample into a single feature vector.
    val flattened = secondStage.view(-1, 16 * 4 * 4)
    val hidden = F.relu(fc2(F.relu(fc1(flattened))))
    fc3(hidden)
/** ResNet architecture implementations */
object resnet:

  /** 3x3 convolution with padding and without bias.
    *
    * `padding` is set to `dilation`, so with `stride = 1` the output keeps the input's spatial size.
    *
    * @param inPlanes number of input channels
    * @param outPlanes number of output channels
    * @param stride convolution stride
    * @param groups number of channel groups
    * @param dilation spacing between kernel elements (also used as padding)
    */
  def conv3x3[D <: BFloat16 | Float32 | Float64: Default](
      inPlanes: Int,
      outPlanes: Int,
      stride: Int = 1,
      groups: Int = 1,
      dilation: Int = 1
  ): Conv2d[D] =
    Conv2d[D](
      inPlanes,
      outPlanes,
      kernelSize = 3,
      stride = stride,
      padding = dilation,
      groups = groups,
      bias = false,
      dilation = dilation
    )

  /** 1x1 convolution without bias.
    *
    * Used to change the channel count inside Bottleneck blocks and as the projection in the
    * downsample shortcut built by `ResNet.makeLayer`.
    */
  def conv1x1[D <: FloatNN: Default](inPlanes: Int, outPlanes: Int, stride: Int = 1): Conv2d[D] =
    Conv2d[D](inPlanes, outPlanes, kernelSize = 1, stride = stride, bias = false)

  /** Factory for one kind of residual block.
    *
    * `expansion` is the ratio between a block's output channels and its `planes` argument.
    */
  sealed abstract class BlockBuilder:
    val expansion: Int

    /** Creates a block of this kind with the given configuration. */
    def apply[D <: BFloat16 | Float32 | Float64: Default](
        inplanes: Int,
        planes: Int,
        stride: Int = 1,
        downsample: Option[TensorModule[D]] = None,
        groups: Int = 1,
        baseWidth: Int = 64,
        dilation: Int = 1,
        normLayer: (Int => HasWeight[D] & TensorModule[D])
    ): TensorModule[D] = this match
      case BasicBlock =>
        new BasicBlock(inplanes, planes, stride, downsample, groups, baseWidth, dilation, normLayer)
      case Bottleneck =>
        new Bottleneck(inplanes, planes, stride, downsample, groups, baseWidth, dilation, normLayer)

  object BasicBlock extends BlockBuilder:
    override val expansion: Int = 1

  object Bottleneck extends BlockBuilder:
    override val expansion: Int = 4

  /** Basic residual block: two 3x3 convolutions plus a shortcut connection.
    *
    * @throws IllegalArgumentException if `groups != 1` or `baseWidth != 64`
    * @throws NotImplementedError if `dilation > 1`
    */
  class BasicBlock[D <: BFloat16 | Float32 | Float64: Default](
      inplanes: Int,
      planes: Int,
      stride: Int = 1,
      downsample: Option[TensorModule[D]] = None,
      groups: Int = 1,
      baseWidth: Int = 64,
      dilation: Int = 1,
      normLayer: => (Int => TensorModule[D])
  ) extends TensorModule[D] {
    if groups != 1 || baseWidth != 64 then
      throw new IllegalArgumentException("BasicBlock only supports groups=1 and baseWidth=64")
    if dilation > 1 then throw new NotImplementedError("Dilation > 1 not supported in BasicBlock")
    // Both conv1 and downsample layers downsample the input when stride != 1
    val conv1 = register(conv3x3[D](inplanes, planes, stride))
    val bn1 = register(normLayer(planes))
    val relu = register(ReLU(inplace = true))
    val conv2 = register(conv3x3[D](planes, planes))
    val bn2 = register(normLayer(planes))
    downsample.foreach(downsample => register(downsample)(using Name("downsample")))

    def apply(x: Tensor[D]): Tensor[D] =
      var out = conv1(x)
      out = bn1(out)
      out = relu(out)
      out = conv2(out)
      out = bn2(out)
      // Shortcut: the input itself, or its downsampled version when a downsample module is given.
      val identity = downsample.fold(x)(ds => ds(x))
      out += identity
      out = relu(out)
      out

    override def toString(): String = getClass().getSimpleName()
  }

  /** Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv (expanding the channels by
    * `expansion`), plus a shortcut connection.
    *
    * Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    * while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    * according to "Deep residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
    * This variant is also known as ResNet V1.5 and improves accuracy according to
    * https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    */
  class Bottleneck[D <: BFloat16 | Float32 | Float64: Default](
      inplanes: Int,
      planes: Int,
      stride: Int = 1,
      downsample: Option[TensorModule[D]] = None,
      groups: Int = 1,
      baseWidth: Int = 64,
      dilation: Int = 1,
      normLayer: (Int => HasWeight[D] & TensorModule[D])
  ) extends TensorModule[D]:
    import Bottleneck.expansion
    // Inner channel width: `planes` scaled by baseWidth / 64, times the number of groups.
    val width = (planes * (baseWidth / 64.0)).toInt * groups
    // Both self.conv2 and self.downsample layers downsample the input when stride != 1
    val conv1 = register(conv1x1[D](inplanes, width))
    val bn1 = register(normLayer(width))
    val conv2 = register(conv3x3[D](width, width, stride, groups, dilation))
    val bn2 = register(normLayer(width))
    val conv3 = register(conv1x1[D](width, planes * expansion))
    val bn3 = register(normLayer(planes * expansion))
    val relu = register(ReLU(inplace = true))
    downsample.foreach(downsample => register(downsample)(using Name("downsample")))

    def apply(x: Tensor[D]): Tensor[D] =
      var out = conv1(x)
      out = bn1(out)
      out = relu(out)
      out = conv2(out)
      out = bn2(out)
      out = relu(out)
      out = conv3(out)
      out = bn3(out)
      // Shortcut: the input itself, or its downsampled version when a downsample module is given.
      val identity = downsample.fold(x)(ds => ds(x))
      out += identity
      out = relu(out)
      out

    override def toString(): String = getClass().getSimpleName()

  /** ResNet: a convolutional stem, four stages of residual blocks, global average pooling and a
    * linear classifier.
    *
    * @param block residual block kind (BasicBlock or Bottleneck) used in every stage
    * @param layers number of blocks in each of the four stages (`layers(0)` to `layers(3)`)
    * @param numClasses output size of the final fully connected layer
    * @param zeroInitResidual if true, zero-initialize the weight of the last norm layer in every
    *   residual branch (`bn3` of a Bottleneck, `bn2` of a BasicBlock)
    * @param groups number of convolution groups passed to every block
    * @param widthPerGroup passed to every block as `baseWidth`
    * @param replaceStrideWithDilation for stages 2, 3 and 4: whether to replace the stride-2
    *   downsampling with a dilated convolution
    * @param normLayer creates a normalization layer for a given number of features
    *   (defaults to BatchNorm2d)
    */
  class ResNet[D <: BFloat16 | Float32 | Float64](
      block: BlockBuilder,
      layers: Seq[Int],
      numClasses: Int = 1000,
      zeroInitResidual: Boolean = false,
      groups: Int = 1,
      widthPerGroup: Int = 64,
      // each element in the tuple indicates if we should replace
      // the 2x2 stride with a dilated convolution instead
      replaceStrideWithDilation: (Boolean, Boolean, Boolean) = (false, false, false)
  )(using Default[D])(
      normLayer: (Int => HasWeight[D] & TensorModule[D]) =
        (numFeatures => BatchNorm2d[D](numFeatures))
  ) extends Module {
    // Channel count fed into the next block; updated by makeLayer as stages are built.
    var inplanes = 64
    // Current dilation; grows in makeLayer when a stride is replaced by dilation.
    var dilation = 1
    val baseWidth = widthPerGroup
    val conv1 = register(Conv2d(3, inplanes, kernelSize = 7, stride = 2, padding = 3, bias = false))
    val bn1 = register(normLayer(inplanes))
    val relu = register(ReLU(inplace = true))
    val maxpool = register(MaxPool2d(kernelSize = 3, stride = Some(2), padding = 1))
    val layer1 = register(makeLayer(block, 64, layers(0)))
    val layer2 = register(
      makeLayer(block, 128, layers(1), stride = 2, dilate = replaceStrideWithDilation(0))
    )
    val layer3 = register(
      makeLayer(block, 256, layers(2), stride = 2, dilate = replaceStrideWithDilation(1))
    )
    val layer4 = register(
      makeLayer(block, 512, layers(3), stride = 2, dilate = replaceStrideWithDilation(2))
    )
    val avgpool = register(AdaptiveAvgPool2d((1, 1)))
    val fc = register(Linear(512 * block.expansion, numClasses))

    for (m <- modules)
      m match
        case m: Conv2d[?] =>
          nn.init.kaimingNormal_(m.weight, mode = Mode.FanOut, nonlinearity = NonLinearity.ReLU)
        case m: BatchNorm2d[?] =>
          nn.init.constant_(m.weight, 1)
          nn.init.constant_(m.bias, 0)
        case m: nn.GroupNorm[?] =>
          nn.init.constant_(m.weight, 1)
          nn.init.constant_(m.bias, 0)
        case _ => () // other modules keep their own initialization

    // Zero-initialize the last BN in each residual branch,
    // so that the residual branch starts with zeros, and each residual block behaves like an identity.
    // This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    if zeroInitResidual then
      for (m <- modules)
        m match
          case m: Bottleneck[?] if m.bn3.weight.dtype != DType.undefined =>
            nn.init.constant_(m.bn3.weight, 0)
          case m: BasicBlock[?] =>
            // BasicBlock's norm layers are typed as plain TensorModule, so check for a weight first.
            m.bn2 match
              case bn: HasWeight[?] if bn.weight.dtype != DType.undefined =>
                nn.init.constant_(bn.weight, 0)
              case _ => ()
          // Every other module (convolutions, pools, ...) must be skipped, not fail with MatchError.
          case _ => ()

    /** Builds one stage made of `blocks` residual blocks with `planes` base channels.
      *
      * Only the first block may change the resolution (`stride`) or the channel count. When either
      * changes, it gets a 1x1 convolution + norm shortcut (`downsample`). If `dilate` is true, the
      * stride is folded into `dilation` and the first block runs with stride 1.
      * Side effects: updates `inplanes` and, when `dilate` is true, `dilation`.
      */
    private def makeLayer(
        block: BlockBuilder,
        planes: Int,
        blocks: Int,
        stride: Int = 1,
        dilate: Boolean = false
    ): Sequential[D] = {
      val previousDilation = this.dilation
      val effectiveStride = if dilate then 1 else stride
      if dilate then this.dilation *= stride
      val outPlanes = planes * block.expansion
      val downsample: Option[TensorModule[D]] =
        if effectiveStride != 1 || inplanes != outPlanes then
          Some(
            Sequential(
              conv1x1[D](inplanes, outPlanes, effectiveStride),
              normLayer(outPlanes)
            )
          )
        else None
      // The first block still uses the dilation that was active before this stage.
      val first = block(
        this.inplanes,
        planes,
        effectiveStride,
        downsample,
        this.groups,
        this.baseWidth,
        previousDilation,
        normLayer
      )
      inplanes = outPlanes
      val rest: Vector[TensorModule[D]] = Vector.fill(blocks - 1)(
        block(
          this.inplanes,
          planes,
          groups = this.groups,
          baseWidth = this.baseWidth,
          dilation = this.dilation,
          normLayer = normLayer
        )
      )
      Sequential((first +: rest)*)
    }

    private inline def forwardImpl(_x: Tensor[D]): Tensor[D] =
      var x = conv1(_x)
      x = bn1(x)
      x = relu(x)
      x = maxpool(x)
      x = layer1(x)
      x = layer2(x)
      x = layer3(x)
      x = layer4(x)
      x = avgpool(x)
      x = x.flatten(1)
      fc(x)

    def apply(x: Tensor[D]): Tensor[D] = forwardImpl(x)
  }

  private val weightsBaseUrl =
    "https://github.com/sbrunk/storch/releases/download/pretrained-weights/"

  /** Pretrained weights: where to download them and the preprocessing preset to use with them. */
  case class Weights(
      url: String,
      transforms: Presets.ImageClassification
  )

  /** Creates a specific ResNet variant and names its default pretrained weights. */
  abstract class ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000): ResNet[D]
    val DEFAULT: Weights

  /** ResNet-18 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf).
    */
  object ResNet18 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000): ResNet[D] =
      ResNet[D](BasicBlock, Seq(2, 2, 2, 2), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet18-f37072fd.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val DEFAULT = IMAGENET1K_V1

  /** ResNet-34 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf).
    */
  object ResNet34 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000): ResNet[D] =
      ResNet[D](BasicBlock, Seq(3, 4, 6, 3), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet34-b627a593.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val DEFAULT = IMAGENET1K_V1

  /** ResNet-50 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf)
    */
  object ResNet50 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000): ResNet[D] =
      ResNet[D](Bottleneck, Seq(3, 4, 6, 3), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet50-0676ba61.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val IMAGENET1K_V2 = Weights(
      url = weightsBaseUrl + "resnet50-11ad3fa6.pth",
      transforms = Presets.ImageClassification(cropSize = 224, resizeSize = 232)
    )
    val DEFAULT = IMAGENET1K_V2

  /** ResNet-101 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf)
    */
  object ResNet101 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000): ResNet[D] =
      ResNet[D](Bottleneck, Seq(3, 4, 23, 3), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet101-63fe2227.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val IMAGENET1K_V2 = Weights(
      url = weightsBaseUrl + "resnet101-cd907fc2.pth",
      transforms = Presets.ImageClassification(cropSize = 224, resizeSize = 232)
    )
    val DEFAULT = IMAGENET1K_V2

  /** ResNet-152 from [Deep Residual Learning for Image
    * Recognition](https://arxiv.org/pdf/1512.03385.pdf)
    */
  object ResNet152 extends ResNetFactory:
    def apply[D <: BFloat16 | Float32 | Float64: Default](numClasses: Int = 1000): ResNet[D] =
      ResNet[D](Bottleneck, Seq(3, 8, 36, 3), numClasses = numClasses)()
    val IMAGENET1K_V1 = Weights(
      url = weightsBaseUrl + "resnet152-394f9c45.pth",
      transforms = Presets.ImageClassification(cropSize = 224)
    )
    val IMAGENET1K_V2 = Weights(
      url = weightsBaseUrl + "resnet152-f82ba261.pth",
      transforms = Presets.ImageClassification(cropSize = 224, resizeSize = 232)
    )
    val DEFAULT = IMAGENET1K_V2

  /** The available ResNet depths, each backed by its factory. */
  enum ResNetVariant(val factory: ResNetFactory):
    case ResNet18 extends ResNetVariant(resnet.this.ResNet18)
    case ResNet34 extends ResNetVariant(resnet.this.ResNet34)
    case ResNet50 extends ResNetVariant(resnet.this.ResNet50)
    case ResNet101 extends ResNetVariant(resnet.this.ResNet101)
    case ResNet152 extends ResNetVariant(resnet.this.ResNet152)
def dataLoader: Iterator[(Tensor[Float32], Tensor[Int64])] = ???
val model: ResNet[Float32] = ResNet34(numClasses = 2)
val optimizer = Adam(model.parameters, lr = 1e-5)
val lossFn = nn.loss.CrossEntropyLoss()
// One pass over the data per epoch; the epoch number itself is not needed.
for _ <- 1 to numEpochs do
  dataLoader.foreach { case (input, label) => // one batch at a time
    optimizer.zeroGrad() // clear gradients from the previous batch
    // Scope the native allocations of this batch so they are freed promptly.
    Using.resource(new PointerScope()) { _ =>
      val loss = lossFn(model(input), label) // how far off the prediction is
      loss.backward() // gradients of the loss w.r.t. the weights
      optimizer.step() // adjust the weights using those gradients
    }
  }
def dataLoader: Iterator[(Tensor[Float32], Tensor[Int64])] = ???
val model: ResNet[Float32] = ResNet34(numClasses = 2)
val optimizer = Adam(model.parameters, lr = 1e-5)
val lossFn = nn.loss.CrossEntropyLoss()
// One pass over the data per epoch; the epoch number itself is not needed.
for _ <- 1 to numEpochs do
  dataLoader.foreach { case (input, label) => // one batch at a time
    optimizer.zeroGrad() // clear gradients from the previous batch
    // Scope the native allocations of this batch so they are freed promptly.
    Using.resource(new PointerScope()) { _ =>
      val loss = lossFn(model(input), label) // how far off the prediction is
      loss.backward() // gradients of the loss w.r.t. the weights
      optimizer.step() // adjust the weights using those gradients
    }
  }
def dataLoader: Iterator[(Tensor[Float32], Tensor[Int64])] = ???
val model: ResNet[Float32] = ResNet34(numClasses = 2)
val optimizer = Adam(model.parameters, lr = 1e-5)
val lossFn = nn.loss.CrossEntropyLoss()
// One pass over the data per epoch; the epoch number itself is not needed.
for _ <- 1 to numEpochs do
  dataLoader.foreach { case (input, label) => // one batch at a time
    optimizer.zeroGrad() // clear gradients from the previous batch
    // Scope the native allocations of this batch so they are freed promptly.
    Using.resource(new PointerScope()) { _ =>
      val loss = lossFn(model(input), label) // how far off the prediction is
      loss.backward() // gradients of the loss w.r.t. the weights
      optimizer.step() // adjust the weights using those gradients
    }
  }
PyTorch
Storch
DType.scala
/** Type of the output tensor based on PyTorch type promotion rules
*
* This is a type-level implementation of the PyTorch op data type promotion rules via match types.
* Match type cases are tried in order, so the sequence of the cases below is significant.
* NOTE(review): this listing is an excerpt; the `// ...` lines presumably stand for omitted cases.
*/
type Promoted[T <: DType, U <: DType] <: DType = (T, U) match
// Both operands have the same dtype: the result keeps it.
case (T, T) => T
case (U, U) => U
// ...
// Int64 combined with another dtype yields the other operand's dtype.
// NOTE(review): presumably the elided cases above already handle dtypes narrower than Int64 — confirm.
case (Int64, U) => U
case (T, Int64) => T
// Mixing the two 16-bit float formats promotes to Float32.
case (Float16, BFloat16) | (BFloat16, Float16) => Float32
// Otherwise Float16 combined with another dtype yields the other operand's dtype.
case (Float16, U) => U
case (T, Float16) => T
// ...
Tensor.scala
DType.scala
/** Promoted type for tensor division
*
* If both operands are `BitwiseNN` dtypes the result is Float32; every other combination
* falls back to the general `Promoted` rules.
* NOTE(review): `BitwiseNN` presumably groups the bool/integer dtypes — confirm.
*/
type Div[T <: DType, U <: DType] <: DType = (T, U) match
case (BitwiseNN, BitwiseNN) => Float32
case _ => Promoted[T, U]
/** Promoted type for elementwise tensor sum
*
* A `BitwiseNN` dtype yields Int64; any other dtype is kept unchanged
* (the final `case D => D` matches every remaining dtype).
*/
type Sum[D <: DType] <: DType = D match
case BitwiseNN => Int64
case D => D
Tensor.scala
🫵