Skip to content

Commit c4b553e

Browse files
committed
Add loss functions
1 parent 0a6b6b6 commit c4b553e

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/maths/loss_functions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Module loss_functions. Implements basic loss functions
2+
geared towards using PyTorch
3+
4+
"""
5+
6+
import torch
7+
8+
9+
def mse(returns: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
10+
11+
value_error = returns - values
12+
loss = value_error.pow(2).mul(0.5).mean()
13+
return loss

0 commit comments

Comments
 (0)