Skip to content

Commit b3acd8d

Browse files
Add ArgMaxSampler public init (#392)
1 parent 870a89b commit b3acd8d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ public struct GenerateParameters: Sendable {
126126

127127
/// Sampler that uses `argMax` (most likely) to sample the logits.
128128
public struct ArgMaxSampler: LogitSampler {
129-
public func sample(logits: MLX.MLXArray) -> MLX.MLXArray {
129+
public init() {}
130+
131+
public func sample(logits: MLXArray) -> MLXArray {
130132
argMax(logits, axis: -1)
131133
}
132134
}

0 commit comments

Comments
 (0)