Skip to content

Commit 8f05b22

Browse files
authored
Add algorithm ReLU_function (#569)
1 parent 887b6bd commit 8f05b22

File tree

2 files changed

+134
-0
lines changed

2 files changed

+134
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
using Algorithms.Numeric;
2+
using NUnit.Framework;
3+
using System;
4+
5+
namespace Algorithms.Tests.Numeric;
6+
7+
[TestFixture]
8+
public static class ReluTests
9+
{
10+
// Tolerance for floating-point comparisons
11+
private const double Tolerance = 1e-9;
12+
13+
// --- SCALAR TESTS (Relu.Compute(double)) ---
14+
15+
[TestCase(0.0, 0.0)]
16+
[TestCase(1.0, 1.0)]
17+
[TestCase(-1.0, 0.0)]
18+
[TestCase(5.0, 5.0)]
19+
[TestCase(-5.0, 0.0)]
20+
public static void ReluFunction_Scalar_ReturnsCorrectValue(double input, double expected)
21+
{
22+
var result = Relu.Compute(input);
23+
Assert.That(result, Is.EqualTo(expected).Within(Tolerance));
24+
}
25+
26+
[Test]
27+
public static void ReluFunction_Scalar_HandlesLimitsAndNaN()
28+
{
29+
// Positive infinity stays +Infinity, negative infinity becomes 0, NaN propagates
30+
Assert.That(RelUComputePositiveInfinity(), Is.EqualTo(double.PositiveInfinity));
31+
Assert.That(RelUComputeNegativeInfinity(), Is.EqualTo(0.0).Within(Tolerance));
32+
Assert.That(RelUComputeNaN(), Is.NaN);
33+
34+
static double RelUComputePositiveInfinity() => Relu.Compute(double.PositiveInfinity);
35+
static double RelUComputeNegativeInfinity() => Relu.Compute(double.NegativeInfinity);
36+
static double RelUComputeNaN() => Relu.Compute(double.NaN);
37+
}
38+
39+
[TestCase(100.0)]
40+
[TestCase(0.0001)]
41+
[TestCase(-100.0)]
42+
public static void ReluFunction_Scalar_ResultIsNonNegative(double input)
43+
{
44+
var result = Relu.Compute(input);
45+
Assert.That(result, Is.GreaterThanOrEqualTo(0.0));
46+
}
47+
48+
// --- VECTOR TESTS (Relu.Compute(double[])) ---
49+
50+
[Test]
51+
public static void ReluFunction_Vector_ReturnsCorrectValues()
52+
{
53+
var input = new[] { 0.0, 1.0, -2.0 };
54+
var expected = new[] { 0.0, 1.0, 0.0 };
55+
56+
var result = Relu.Compute(input);
57+
58+
Assert.That(result, Is.EqualTo(expected).Within(Tolerance));
59+
}
60+
61+
[Test]
62+
public static void ReluFunction_Vector_HandlesLimitsAndNaN()
63+
{
64+
var input = new[] { double.PositiveInfinity, 0.0, double.NaN };
65+
var result = Relu.Compute(input);
66+
67+
Assert.That(result.Length, Is.EqualTo(input.Length));
68+
Assert.That(result[0], Is.EqualTo(double.PositiveInfinity));
69+
Assert.That(result[1], Is.EqualTo(0.0).Within(Tolerance));
70+
Assert.That(result[2], Is.NaN);
71+
}
72+
73+
// --- EXCEPTION TESTS ---
74+
75+
[Test]
76+
public static void ReluFunction_Vector_ThrowsOnNullInput()
77+
{
78+
double[]? input = null;
79+
Assert.Throws<ArgumentNullException>(() => Relu.Compute(input!));
80+
}
81+
82+
[Test]
83+
public static void ReluFunction_Vector_ThrowsOnEmptyInput()
84+
{
85+
var input = Array.Empty<double>();
86+
Assert.Throws<ArgumentException>(() => Relu.Compute(input));
87+
}
88+
}

Algorithms/Numeric/Relu.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
namespace Algorithms.Numeric;
2+
3+
/// <summary>
4+
/// Implementation of the Rectified Linear Unit (ReLU) function.
5+
/// ReLU is defined as: ReLU(x) = max(0, x).
6+
/// It is commonly used as an activation function in neural networks.
7+
/// </summary>
8+
public static class Relu
9+
{
10+
/// <summary>
11+
/// Compute the Rectified Linear Unit (ReLU) for a single value.
12+
/// </summary>
13+
/// <param name="input">The input real number.</param>
14+
/// <returns>The output real number (>= 0).</returns>
15+
public static double Compute(double input)
16+
{
17+
return Math.Max(0.0, input);
18+
}
19+
20+
/// <summary>
21+
/// Compute the Rectified Linear Unit (ReLU) element-wise for a vector.
22+
/// </summary>
23+
/// <param name="input">The input vector of real numbers.</param>
24+
/// <returns>The output vector where each element is max(0, input[i]).</returns>
25+
public static double[] Compute(double[] input)
26+
{
27+
if (input is null)
28+
{
29+
throw new ArgumentNullException(nameof(input));
30+
}
31+
32+
if (input.Length == 0)
33+
{
34+
throw new ArgumentException("Array is empty.");
35+
}
36+
37+
var output = new double[input.Length];
38+
39+
for (var i = 0; i < input.Length; i++)
40+
{
41+
output[i] = Math.Max(0.0, input[i]);
42+
}
43+
44+
return output;
45+
}
46+
}

0 commit comments

Comments
 (0)