diff --git a/basalt/__init__.mojo b/basalt/__init__.mojo index 137eec4c..73d84f1d 100644 --- a/basalt/__init__.mojo +++ b/basalt/__init__.mojo @@ -5,3 +5,4 @@ from basalt.utils.collection import Collection alias dtype = DType.float32 alias nelts = 2 * simdwidthof[dtype]() alias seed = 42 +alias epsilon = 1e-12 diff --git a/basalt/autograd/attributes.mojo b/basalt/autograd/attributes.mojo index 9be18227..3e1c3b3d 100644 --- a/basalt/autograd/attributes.mojo +++ b/basalt/autograd/attributes.mojo @@ -1,4 +1,5 @@ from collections import Optional, OptionalReg +from utils.static_tuple import StaticTuple from basalt.nn.tensor import Tensor, TensorShape, MAX_RANK from basalt.utils.bytes import Bytes, scalar_to_bytes, bytes_to_scalar @@ -45,9 +46,8 @@ struct AttributeVector(Sized, Stringable, CollectionElement): var attributes: StaticTuple[Attribute, MAX_ATTRS] var size: Int - @always_inline("nodebug") fn __init__(inout self, *attributes: Attribute): - self.attributes = StaticTuple[Attribute, MAX_ATTRS]() + self.attributes = StaticTuple[Attribute, MAX_ATTRS](Attribute("", "")) self.size = len(attributes) for i in range(self.size): self.attributes[i] = attributes[i] @@ -67,7 +67,10 @@ struct AttributeVector(Sized, Stringable, CollectionElement): return self.attributes[i] return None - @always_inline("nodebug") + fn append(inout self, attribute: Attribute): + self.attributes[self.size] = attribute + self.size += 1 + fn __str__(self) -> String: var s: String = "[" for i in range(self.size): @@ -85,7 +88,6 @@ struct Attribute(Stringable, CollectionElement): var type: AttributeType var size: Int - @always_inline("nodebug") fn __init__(inout self, name: String, value: String): self.data_shape = StaticIntTuple[MAX_RANK]() self.name = Bytes[MAX_NAME_CHARS](name) @@ -93,7 +95,6 @@ struct Attribute(Stringable, CollectionElement): self.type = AttributeType.STRING self.size = len(value) - @always_inline("nodebug") fn __init__(inout self, name: String, value: TensorShape): self.data_shape = StaticIntTuple[MAX_RANK]() self.name = Bytes[MAX_NAME_CHARS](name) @@ -104,7 +105,6 @@ struct Attribute(Stringable, CollectionElement): for i in range(self.size): self.data_shape[i] = value._shape[i] - @always_inline("nodebug") fn __init__[N: Int](inout self, name: String, value: StaticIntTuple[N]): constrained[N < MAX_RANK, "Attribute rank must be less than MAX_RANK."]() @@ -115,9 +115,8 @@ struct Attribute(Stringable, CollectionElement): self.size = N for i in range(self.size): - self.data[i] = value[i] + self.data_shape[i] = value[i] - @always_inline("nodebug") fn __init__[dtype: DType](inout self, name: String, value: Scalar[dtype]): constrained[dtype.is_numeric(), "Attribute value must be numeric."]() @@ -127,46 +126,38 @@ struct Attribute(Stringable, CollectionElement): self.type = AttributeType(dtype) self.size = 1 - @always_inline("nodebug") fn __init__(inout self, name: String, value: Int): self.__init__(name, Int64(value)) self.data_shape[0] = 1 - @always_inline("nodebug") fn __init__(inout self, name: String, value: FloatLiteral): self.__init__(name, Float64(value)) self.data_shape[0] = 1 - @always_inline("nodebug") fn __str__(self) -> String: return "Attribute(" + str(self.name) + ", " + "..." + ")" - @always_inline("nodebug") fn to_string(self) -> String: return str(self.data) - @always_inline("nodebug") fn to_shape(self) -> TensorShape: return TensorShape(rank=self.size, shape=self.data_shape) - @always_inline("nodebug") fn to_static[N: Int](self) -> StaticIntTuple[N]: constrained[N < MAX_RANK, "Attribute rank must be less than MAX_RANK."]() var result = StaticIntTuple[N]() for i in range(N): - result[i] = int(self.data[i]) + result[i] = int(self.data_shape[i]) return result - @always_inline("nodebug") fn to_scalar[dtype: DType](self) -> Scalar[dtype]: constrained[dtype.is_numeric(), "Attribute value must be numeric."]() return bytes_to_scalar[dtype](self.data) - @always_inline("nodebug") fn to_int(self) -> Int: return int(self.to_scalar[DType.int64]()) diff --git a/basalt/autograd/graph.mojo b/basalt/autograd/graph.mojo index 23537740..cd28b095 100644 --- a/basalt/autograd/graph.mojo +++ b/basalt/autograd/graph.mojo @@ -11,7 +11,6 @@ from basalt import seed, dtype from basalt import Tensor, TensorShape -@value struct Graph: var inputs: List[Symbol] var params: ParamDict @@ -28,41 +27,42 @@ struct Graph: self.loss_out = None self.symbol_count = 0 - fn input(inout self, shape: TensorShape, trainable: Bool = False) -> Symbol: - var inp = Symbol(self.symbol_count, dtype, shape, trainable) - self.inputs.append(inp) - self.symbol_count += 1 - return inp + fn __moveinit__(inout self, owned other: Graph): + self.inputs = other.inputs^ + self.params = other.params^ + self.nodes = other.nodes^ + self.outputs = other.outputs^ + self.loss_out = other.loss_out + self.symbol_count = other.symbol_count - fn param( - inout self, shape: TensorShape, init: Param, trainable: Bool = True - ) -> Symbol: - var param_id = Symbol(self.symbol_count, dtype, shape, trainable) - self.params.put(param_id, init) + fn create_symbol(inout self, shape: TensorShape, data: Optional[Param] = None, trainable: Bool = False, is_input: Bool = False) -> Symbol: + var symbol = Symbol(self.symbol_count, dtype, shape, trainable) self.symbol_count += 1 - return param_id + + if is_input: + self.inputs.append(symbol) + else: + if data is not None: + self.params.put(symbol, data.value()[]) + else: + self.params.put(symbol) + + return symbol + + fn input(inout self, shape: TensorShape, trainable: Bool = False) -> Symbol: + return self.create_symbol(shape, trainable=trainable, is_input=True) + + fn param(inout self, shape: TensorShape, init: Param, trainable: Bool = True) -> Symbol: + return self.create_symbol(shape, init, trainable) fn param(inout self, shape: TensorShape, trainable: Bool = True) -> Symbol: - var param_id = Symbol(self.symbol_count, dtype, shape, trainable) - self.params.put(param_id) - self.symbol_count += 1 - return param_id + return self.create_symbol(shape, trainable=trainable) fn scalar(inout self, value: Scalar[dtype]) -> Symbol: - var scal = Param(value) - var scalar_id = Symbol( - self.symbol_count, dtype, TensorShape(1), trainable=False - ) - self.params.put(scalar_id, scal) - self.symbol_count += 1 - return scalar_id + return self.create_symbol(TensorShape(1), Param(value), trainable=False) fn constant(inout self, shape: TensorShape, data: List[Scalar[dtype]]) -> Symbol: - var cst = Param(data) - var constant_id = Symbol(self.symbol_count, dtype, shape, trainable=False) - self.params.put(constant_id, cst) - self.symbol_count += 1 - return constant_id + return self.create_symbol(shape, Param(data), trainable=False) fn out(inout self, symbol: Symbol): self.outputs.append(symbol) @@ -77,14 +77,15 @@ struct Graph: attributes: AttributeVector = AttributeVector(), ) -> Symbol: var res_shape = static_result_shape(op, operands, attributes) - var res = Symbol( - self.symbol_count, dtype, res_shape, self.result_trainable(operands) - ) + var res = Symbol(self.symbol_count, dtype, res_shape, self.result_trainable(operands)) self.symbol_count += 1 var inputs = List[Symbol]() + inputs.reserve(len(operands)) + for operand in operands: inputs.append(operand) + self.nodes.append(Node(op, inputs, List[Symbol](res), attributes)) return res @@ -95,8 +96,7 @@ struct Graph: operand_2: Float64, attributes: AttributeVector = AttributeVector(), ) -> Symbol: - var operand_2_symbol = self.scalar(operand_2) - return self.op(op, operand_1, operand_2_symbol, attributes=attributes) + return self.op(op, operand_1, self.scalar(operand_2), attributes=attributes) fn op( inout self, @@ -105,43 +105,43 @@ struct Graph: operand_2: Symbol, attributes: AttributeVector = AttributeVector(), ) -> Symbol: - var operand_1_symbol = self.scalar(operand_1) - return self.op(op, operand_1_symbol, operand_2, attributes=attributes) + return self.op(op, self.scalar(operand_1), operand_2, attributes=attributes) + + fn create_symbols(inout self, shapes: List[TensorShape], trainable: Bool = False) -> List[Symbol]: + var symbols = List[Symbol]() + symbols.reserve(len(shapes)) + + for shape in shapes: + symbols.append(Symbol(self.symbol_count, dtype, shape[], trainable)) + self.symbol_count += 1 + + return symbols + + fn add_node(inout self, op: OP, inputs: List[Symbol], outputs: List[Symbol], attributes: AttributeVector): + self.nodes.append(Node(op, inputs, outputs, attributes)) - # Dynamic ops fn concat(inout self, *operands: Symbol, dim: Int = 0) -> Symbol: - # NOTE: Concat could fit into g.op() given a different static_result_shape is called var attributes = AttributeVector(Attribute("dim", dim)) - var res_shape = dynamic_result_shape(OP.CONCAT, operands, attributes)[0] - var res = Symbol( - self.symbol_count, dtype, res_shape, self.result_trainable(operands) - ) - self.symbol_count += 1 + var res_symbols = self.create_symbols(List[TensorShape](res_shape), self.result_trainable(operands)) - var inputs = List[Symbol]() + var operand_list = List[Symbol]() + operand_list.reserve(len(operands)) for operand in operands: - inputs.append(operand) - self.nodes.append(Node(OP.CONCAT, inputs, List[Symbol](res), attributes)) - return res + operand_list.append(operand) + + self.add_node(OP.CONCAT, operand_list, res_symbols, attributes) + return res_symbols[0] fn split( inout self, operand: Symbol, sections: List[Int], dim: Int = 0 ) -> List[Symbol]: - var attributes = AttributeVector( - Attribute("sections", TensorShape(sections)), Attribute("dim", dim) - ) + var attributes = AttributeVector(Attribute("sections", TensorShape(sections)), Attribute("dim", dim)) var res_shapes = dynamic_result_shape(OP.SPLIT, operand, attributes) var trainable = self.result_trainable(operand) - - var results = List[Symbol]() - for i in range(len(res_shapes)): - var symbol = Symbol(self.symbol_count, dtype, res_shapes[i], trainable) - results.append(symbol) - self.symbol_count += 1 - - self.nodes.append(Node(OP.SPLIT, List[Symbol](operand), results, attributes)) - return results + var result_symbols = self.create_symbols(res_shapes, trainable) + self.add_node(OP.SPLIT, List[Symbol](operand), result_symbols, attributes) + return result_symbols @staticmethod fn result_trainable(operands: VariadicList[Symbol]) -> Bool: diff --git a/basalt/autograd/ops/basics.mojo b/basalt/autograd/ops/basics.mojo index 3a4a5ab2..74662ca9 100644 --- a/basalt/autograd/ops/basics.mojo +++ b/basalt/autograd/ops/basics.mojo @@ -1,11 +1,15 @@ -from math import add, sub, mul, div, log, exp +from math import log, exp from algorithm import vectorize from memory import memcpy +from utils.numerics import isinf from basalt import Tensor, TensorShape from basalt.nn.tensor import MAX_RANK from basalt.utils.tensorutils import * from basalt.autograd.attributes import Attribute, AttributeVector +from basalt.autograd.ops.matmul import dot, dot_transpose_t1, dot_transpose_t2 +from basalt.utils.math_util import add, sub, mul, div + """ Implement forward and backward operations for basic tensor manipulations. @@ -315,7 +319,9 @@ struct POW: # d(x^y) / dx = y * x^(y-1) # d(x^y) / dy = sum( x^y * log(x) ) var res_grad: Tensor[dtype] - var a = int(t2[0]) + var a = t2[0] + + alias epsilon = 1e-12 @parameter if tensor_id == 0: @@ -323,20 +329,23 @@ struct POW: @parameter fn vec_pow_bw_x[nelts: Int](i: Int): - res_grad.store[nelts]( - i, a * (t1.load[nelts](i) ** (a - 1)) * ug.load[nelts](i) - ) + res_grad.store[nelts](i, a * ((t1.load[nelts](i) + epsilon) ** (a - 1)) * ug.load[nelts](i)) vectorize[vec_pow_bw_x, nelts](t1_shape.num_elements()) else: + # Gradient of the exponent res_grad = Tensor[dtype](t2_shape) # t2_shape == TensorShape(1) @parameter fn vec_pow_bw_y[nelts: Int](i: Int): + # the case when the value passed to log is 0.0 + var temp_log = log(t1.load[nelts](i)) + var temp_log_is_inf = isinf(temp_log) + temp_log = temp_log_is_inf.select(0, temp_log) res_grad[0] += ( (t1.load[nelts](i) ** a) - * log(t1.load[nelts](i)) + * temp_log * ug.load[nelts](i) ).reduce_add() diff --git a/basalt/autograd/ops/conv.mojo b/basalt/autograd/ops/conv.mojo index 3e1a18c0..774eb031 100644 --- a/basalt/autograd/ops/conv.mojo +++ b/basalt/autograd/ops/conv.mojo @@ -1,9 +1,7 @@ from basalt import Tensor, TensorShape from basalt.autograd.attributes import AttributeVector -from basalt.utils.tensorutils import dot, dot_transpose_t1, dot_transpose_t2 from algorithm import parallelize, vectorize, tile -from math import divmod from utils.loop import unroll diff --git a/basalt/autograd/ops/dynamics.mojo b/basalt/autograd/ops/dynamics.mojo index 0f304efb..5c30493c 100644 --- a/basalt/autograd/ops/dynamics.mojo +++ b/basalt/autograd/ops/dynamics.mojo @@ -33,7 +33,7 @@ struct CONCAT: fn forward[attributes: AttributeVector]( inputs: List[Symbol], outputs: List[Symbol], - parameters: Parameters, + inout parameters: Parameters, ): alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 var n_chunks = Self.calc_chunks(inputs[0].shape, dim) @@ -58,7 +58,7 @@ struct CONCAT: fn backward[input_id: Int, attributes: AttributeVector]( inputs: List[Symbol], outputs: List[Symbol], - parameters: Parameters, + inout parameters: Parameters, ) -> Tensor[dtype]: alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 var n_chunks = Self.calc_chunks(inputs[0].shape, dim) @@ -113,7 +113,7 @@ struct SPLIT: fn forward[attributes: AttributeVector]( inputs: List[Symbol], outputs: List[Symbol], - parameters: Parameters, + inout parameters: Parameters, ): alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 alias sections = attributes["sections"].value().to_shape() @@ -139,7 +139,7 @@ struct SPLIT: fn backward[input_id: Int, attributes: AttributeVector]( inputs: List[Symbol], outputs: List[Symbol], - parameters: Parameters, + inout parameters: Parameters, ) -> Tensor[dtype]: alias dim = attributes["dim"].value().to_int() if attributes["dim"] else 0 alias sections = attributes["sections"].value().to_shape() diff --git a/basalt/autograd/ops/matmul.mojo b/basalt/autograd/ops/matmul.mojo new file mode 100644 index 00000000..bc2cf2ba --- /dev/null +++ b/basalt/autograd/ops/matmul.mojo @@ -0,0 +1,175 @@ +from basalt.utils.tensorutils import transpose_2D +from algorithm import vectorize, parallelize + + +@always_inline +fn calculate_block[ + M: Int, N: Int, K: Int, BLOCK_M: Int, BLOCK_N: Int, nelts: Int +]( + res: DTypePointer[dtype], + t1: DTypePointer[dtype], + t2: DTypePointer[dtype], + bm: Int, + bn: Int, +): + # Compute tile + var acc = stack_allocation[BLOCK_M * BLOCK_N, dtype]() + memset_zero[dtype](acc, BLOCK_M * BLOCK_N) + + for k in range(K): + + @parameter + for m in range(BLOCK_M): + + @parameter + fn inner_n[nelts: Int](n: Int): + acc.store[width=nelts]( + m * BLOCK_N + n, + SIMD[dtype, nelts] + .splat(t1[(bm + m) * K + k]) + .fma( + t2.load[width=nelts](k * N + (bn + n)), + acc.load[width=nelts](m * BLOCK_N + n), + ), + ) + + vectorize[inner_n, nelts](BLOCK_N) + + # Store tile + for m in range(BLOCK_M): + + @parameter + fn vec_store[nelts: Int](n: Int): + res.store[width=nelts]( + (bm + m) * N + (bn + n), acc.load[width=nelts](m * BLOCK_N + n) + ) + + vectorize[vec_store, nelts](BLOCK_N) + + +@parameter +@always_inline +fn dot[ + t1_shape: TensorShape, t2_shape: TensorShape +](inout res: Tensor[dtype], t1: Tensor[dtype], t2: Tensor[dtype]): + dot[t1_shape, t2_shape](res.data(), t1.data(), t2.data()) + + +@parameter +@always_inline +fn dot[ + t1_shape: TensorShape, t2_shape: TensorShape +](res: DTypePointer[dtype], t1: DTypePointer[dtype], t2: DTypePointer[dtype]): + alias M = t1_shape[0] # t1[0] + alias K = t1_shape[1] # t1[1], t2[0] + alias N = t2_shape[1] # t2[1] + + # simdwidthof[dtype]() = 8 for float32 + alias nelts = simdwidthof[dtype]() + alias BLOCK_N = 8 * 2 + alias BLOCK_M = 6 + alias THREADS = 6 # num_logical_cores() + + alias BLOCK_N_REMAINDER = N % BLOCK_N + alias BLOCK_M_REMAINDER = M % BLOCK_M + + @parameter + fn bm_par(m_outer: Int): + var bm = m_outer * BLOCK_M + + for n_outer in range(0, N // BLOCK_N): + var bn = n_outer * BLOCK_N + + calculate_block[M, N, K, BLOCK_M, BLOCK_N, nelts](res, t1, t2, bm, bn) + + # Handle the remainder of N + @parameter + if BLOCK_N_REMAINDER > 0: + var bn = N - BLOCK_N_REMAINDER + + calculate_block[M, N, K, BLOCK_M, BLOCK_N_REMAINDER, nelts]( + res, t1, t2, bm, bn + ) + + parallelize[bm_par](M // BLOCK_M, M // BLOCK_M) + + # Handle the remainder of M + @parameter + if BLOCK_M_REMAINDER > 0: + var bm = M - BLOCK_M_REMAINDER + + for n_outer in range(0, N // BLOCK_N): + var bn = n_outer * BLOCK_N + + calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N, nelts]( + res, t1, t2, bm, bn + ) + + # Handle corner remainder + @parameter + if BLOCK_N_REMAINDER > 0: + var bn = N - BLOCK_N_REMAINDER + + calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N_REMAINDER, nelts]( + res, t1, t2, bm, bn + ) + + +fn dot_transpose_t2[ + A_shape: TensorShape, B_shape: TensorShape +](inout C: DTypePointer[dtype], A: DTypePointer[dtype], B: DTypePointer[dtype]): + dot[A_shape, TensorShape(B_shape[1], B_shape[0])](C, A, transpose_2D[B_shape](B)) + + +fn dot_transpose_t2[ + A_shape: TensorShape, B_shape: TensorShape +](inout C: Tensor[dtype], A: Tensor[dtype], B: Tensor[dtype]): + memset_zero[dtype](C.data(), C.num_elements()) + + dot[A_shape, TensorShape(B_shape[1], B_shape[0])](C, A, transpose_2D[B_shape](B)) + + # @parameter + # fn calc_row(i: Int): + # for j in range(B_shape[0]): + + # @parameter + # fn calc_row_A_B[nelts: Int](k: Int): + # var A_pos = i * A.dim(1) + k + # var B_pos = j * A.dim(1) + k + # var t_new_pos = i * C.dim(1) + j + + # C[t_new_pos] += ( + # A.load[nelts](A_pos) * B.load[nelts](B_pos) + # ).reduce_add() + + # vectorize[calc_row_A_B, nelts, size=A_shape[1]]() + + # parallelize[calc_row](A_shape[0], 1) + + +fn dot_transpose_t1[ + A_shape: TensorShape, B_shape: TensorShape +](inout C: Tensor[dtype], A: Tensor[dtype], B: Tensor[dtype]): + memset_zero[dtype](C.data(), C.num_elements()) + + dot[TensorShape(A_shape[1], A_shape[0]), B_shape](C, transpose_2D[A_shape](A), B) + + # @parameter + # fn calc_row(i: Int): + # for j in range(A_shape[0]): + + # @parameter + # fn calc_row_t_new_B[nelts: Int](k: Int): + # var A_pos = j * A.dim(1) + i + # var B_pos = j * B.dim(1) + k + # var t_new_pos = i * C.dim(1) + k + + # C.store[nelts]( + # t_new_pos, + # C.load[nelts](t_new_pos) + # + A[A_pos] * B.load[nelts](B_pos), + # ) + + # vectorize[calc_row_t_new_B, nelts, size=B_shape[1]]() + + # parallelize[calc_row](A_shape[1], 1) diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo index 08699199..30a61e8c 100644 --- a/basalt/autograd/ops/mlops.mojo +++ b/basalt/autograd/ops/mlops.mojo @@ -1,9 +1,11 @@ from algorithm import vectorize, parallelize -from math import exp, pow, max, min, abs -from math.limit import min_finite, max_finite +from math import exp, floor, ceil +from utils.numerics import min_finite, max_finite +from utils.static_tuple import StaticTuple from basalt import Tensor, TensorShape from basalt.utils.tensorutils import elwise_transform +from basalt.utils.itertools import product from basalt.autograd.attributes import Attribute, AttributeVector @@ -52,7 +54,7 @@ struct SIGMOID: vectorize[vec_sigmoid_bw, nelts](ug_shape.num_elements()) - return res_grad ^ + return res_grad^ struct RELU: @@ -100,7 +102,62 @@ struct RELU: vectorize[vec_relu_bw, nelts](ug_shape.num_elements()) - return res_grad ^ + return res_grad^ + + +struct LEAKYRELU: + @staticmethod + fn result_shape(t1_shape: TensorShape) -> TensorShape: + return t1_shape + + @staticmethod + fn forward[ + t1_shape: TensorShape, + attributes: AttributeVector, + ](inout res: Tensor[dtype], t1: Tensor[dtype]): + """Forward operation of leaky_relu.""" + + fn leaky_relu[ + type: DType, + simd_width: Int, + ](x: SIMD[type, simd_width]) -> SIMD[type, simd_width]: + var negative_slope = attributes["negative_slope"].value().to_scalar[ + type + ]() + return (x > 0).select(x, x * negative_slope) + + elwise_transform[leaky_relu](res, t1) + + @staticmethod + fn backward[ + ug_shape: TensorShape, + t1_shape: TensorShape, + attributes: AttributeVector, + ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]: + """Backward operation of leaky_relu.""" + + @always_inline + fn leaky_relu_bw[ + type: DType, simd_width: Int + ](x: SIMD[type, simd_width]) -> SIMD[type, simd_width]: + var negative_slope = attributes["negative_slope"].value().to_scalar[ + type + ]() + + return (x > 0).select[type](1, negative_slope) + + var res_grad = Tensor[dtype](ug_shape) + + @parameter + fn vec_leaky_relu_bw[nelts: Int](idx: Int): + res_grad.store[nelts]( + idx, + leaky_relu_bw(t1.load[nelts](idx)) * ug.load[nelts](idx), + ) + + vectorize[vec_leaky_relu_bw, nelts](ug_shape.num_elements()) + + return res_grad^ struct TANH: @@ -146,7 +203,7 @@ struct TANH: vectorize[vec_tanh_bw, nelts](ug_shape.num_elements()) - return res_grad ^ + return res_grad^ struct CLIP: @@ -164,12 +221,12 @@ struct CLIP: alias min_attr = attributes["min"] alias max_attr = attributes["max"] - var min_val = min_attr.value().to_scalar[dtype]() if min_attr else min_finite[ + var min_val = min_attr.value().to_scalar[ dtype - ]() - var max_val = max_attr.value().to_scalar[dtype]() if max_attr else max_finite[ + ]() if min_attr else min_finite[dtype]() + var max_val = max_attr.value().to_scalar[ dtype - ]() + ]() if max_attr else max_finite[dtype]() @parameter fn vec_clip[nelts: Int](i: Int): @@ -187,12 +244,12 @@ struct CLIP: alias min_attr = attributes["min"] alias max_attr = attributes["max"] - var min_val = min_attr.value().to_scalar[dtype]() if min_attr else min_finite[ + var min_val = min_attr.value().to_scalar[ dtype - ]() - var max_val = max_attr.value().to_scalar[dtype]() if max_attr else max_finite[ + ]() if min_attr else min_finite[dtype]() + var max_val = max_attr.value().to_scalar[ dtype - ]() + ]() if max_attr else max_finite[dtype]() var res_grad = Tensor[dtype](t_shape) @@ -201,17 +258,21 @@ struct CLIP: var val = t.load[nelts](i) res_grad.store[nelts]( i, - ((val >= min_val) * (val <= max_val)).select(ug.load[nelts](i), 0), + ((val >= min_val) * (val <= max_val)).select( + ug.load[nelts](i), 0 + ), ) vectorize[vec_clip_bw, nelts, size = t_shape.num_elements()]() - return res_grad ^ + return res_grad^ struct SQUEEZE: @staticmethod - fn result_shape(t1_shape: TensorShape, attributes: AttributeVector) -> TensorShape: + fn result_shape( + t1_shape: TensorShape, attributes: AttributeVector + ) -> TensorShape: var dim = attributes["dims"] var dims_to_squeeze = dim.value().to_shape() if dim else TensorShape() @@ -239,12 +300,14 @@ struct SQUEEZE: ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]: var res_grad = Tensor[dtype](t1_shape) memcpy(res_grad.data(), ug.data(), ug.num_elements()) - return res_grad ^ + return res_grad^ struct UNSQUEEZE: @staticmethod - fn result_shape(t1_shape: TensorShape, attributes: AttributeVector) -> TensorShape: + fn result_shape( + t1_shape: TensorShape, attributes: AttributeVector + ) -> TensorShape: var dim = attributes["dims"] var dims_to_squeeze = dim.value().to_shape() if dim else TensorShape() @@ -276,7 +339,7 @@ struct UNSQUEEZE: ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]: var res_grad = Tensor[dtype](t1_shape) memcpy(res_grad.data(), ug.data(), ug.num_elements()) - return res_grad ^ + return res_grad^ struct SLICE: @@ -285,7 +348,7 @@ struct SLICE: # Adjust negative indices & ensure they are within bounds. var s = slice if slice >= 0 else dim_size + slice return max(min(s, dim_size), 0) - + @staticmethod fn default_starts(shape: TensorShape) -> List[Int]: var starts = List[Int]() @@ -306,7 +369,7 @@ struct SLICE: for i in range(shape.rank()): steps.append(1) return steps^ - + @staticmethod fn default_axes(shape: TensorShape) -> List[Int]: # NOTE: axes can't be negative @@ -316,38 +379,55 @@ struct SLICE: return axes^ @staticmethod - fn result_shape(t1_shape: TensorShape, attributes: AttributeVector) -> TensorShape: + fn result_shape( + t1_shape: TensorShape, attributes: AttributeVector + ) -> TensorShape: # NOTE: Starts and ends have to be of the same size # NOTE: If axes not provided, starts and ends have to be of the same size as t1_shape var starts = attributes["starts"].value().to_shape() var ends = attributes["ends"].value().to_shape() - var steps = attributes["steps"].value().to_shape() if attributes["steps"] else Self.default_steps(starts) - var axes = attributes["axes"].value().to_shape() if attributes["axes"] else Self.default_axes(t1_shape) + var steps = attributes["steps"].value().to_shape() if attributes[ + "steps" + ] else Self.default_steps(starts) + var axes = attributes["axes"].value().to_shape() if attributes[ + "axes" + ] else Self.default_axes(t1_shape) var new_shape = t1_shape for i in range(starts.rank()): var axis = axes[i] - new_shape[axis] = len(range( - start = Self.adjust_boundary(starts[i], t1_shape[axis]), - end = Self.adjust_boundary(ends[i], t1_shape[axis]), - step = steps[i] - )) + new_shape[axis] = len( + range( + start=Self.adjust_boundary(starts[i], t1_shape[axis]), + end=Self.adjust_boundary(ends[i], t1_shape[axis]), + step=steps[i], + ) + ) return new_shape @staticmethod - fn reorder_positions[id: Int](original: TensorShape, axes: TensorShape, t1_shape: TensorShape) -> List[Int]: + fn reorder_positions[ + id: Int + ](original: TensorShape, axes: TensorShape, t1_shape: TensorShape) -> List[ + Int + ]: # Reorder the starts (id=0), ends (id=1) or steps (id=2) to match the order of the axes var updated: List[Int] @parameter - if id == 0: updated = Self.default_starts(t1_shape) - elif id == 1: updated = Self.default_ends(t1_shape) - else: updated = Self.default_steps(t1_shape) - + if id == 0: + updated = Self.default_starts(t1_shape) + elif id == 1: + updated = Self.default_ends(t1_shape) + else: + updated = Self.default_steps(t1_shape) + for i in range(axes.rank()): var axis = axes[i] - updated[axis] = original[i] if id == 2 else Self.adjust_boundary(original[i], t1_shape[axis]) + updated[axis] = original[i] if id == 2 else Self.adjust_boundary( + original[i], t1_shape[axis] + ) return updated^ @@ -360,12 +440,12 @@ struct SLICE: steps: List[Int], starts: List[Int], ends: List[Int], - backward_op: Bool = False + backward_op: Bool = False, ]( inout res: Tensor[dtype], t1: Tensor[dtype], last_dims: Int, - position: Int, + position: Int, last_position: Int, idx: Int, idx_original: Int, @@ -374,7 +454,9 @@ struct SLICE: alias t1_strides = original_shape.strides() var idx_temp = idx - var idx_original_temp = starts[position] * t1_strides[position] + idx_original + var idx_original_temp = starts[position] * t1_strides[ + position + ] + idx_original if position == last_position + 1: # Work on the last dimensions @@ -382,37 +464,50 @@ struct SLICE: alias stride = t1_strides[position] * steps[position] @parameter - fn v_slice[nelts: Int](k : Int): - + fn v_slice[nelts: Int](k: Int): @parameter if not backward_op: + @parameter if steps[position] == 1: - res.store[nelts](idx_temp + k, t1.load[nelts](idx_original_temp)) + res.store[nelts]( + idx_temp + k, t1.load[nelts](idx_original_temp) + ) else: res.store[nelts]( idx_temp + k, - t1.data().offset(idx_original_temp).simd_strided_load[nelts](stride) + t1.data() + .offset(idx_original_temp) + .simd_strided_load[nelts](stride), ) else: + @parameter if steps[position] == 1: res.store[nelts](idx_original_temp, t1.load[nelts](idx_temp + k)) else: - res.data().offset(idx_original_temp).simd_strided_store[nelts]( + res.data().offset(idx_original_temp).simd_strided_store[width=nelts]( t1.load[nelts](idx_temp + k), stride ) - + idx_original_temp += stride * nelts vectorize[v_slice, nelts](last_dims) - return + return for _ in range(shape[position]): - Self.recursive_iters_slice[shape, original_shape, steps, starts, ends, backward_op]( - res, t1, last_dims, position + 1, last_position, idx_temp, idx_original_temp + Self.recursive_iters_slice[ + shape, original_shape, steps, starts, ends, backward_op + ]( + res, + t1, + last_dims, + position + 1, + last_position, + idx_temp, + idx_original_temp, ) idx_temp += strides[position] @@ -425,10 +520,10 @@ struct SLICE: steps: List[Int], starts: List[Int], ends: List[Int], - backward_op: Bool = False + backward_op: Bool = False, ](inout res: Tensor[dtype], t1: Tensor[dtype]): alias strides = original_shape.strides() - + # Get the dimensions for vectorization var last_dims = 1 var positions_to_skip = 0 @@ -439,7 +534,7 @@ struct SLICE: positions_to_skip += 1 if starts[i] != 0 or ends[i] != original_shape[i] or steps[i] != 1: break - + # Get the dimensions for the first loop var first_dims = 1 var start_position = 0 @@ -450,31 +545,46 @@ struct SLICE: start_position += 1 var middle_dims = res_shape.num_elements() // last_dims // first_dims - + @parameter fn p_slice(i: Int): Self.recursive_iters_slice[ res_shape, original_shape, steps, starts, ends, backward_op ]( - res, t1, last_dims, start_position, res_shape.rank() - 1 - positions_to_skip, - i * middle_dims * last_dims, i * strides[start_position - 1] + res, + t1, + last_dims, + start_position, + res_shape.rank() - 1 - positions_to_skip, + i * middle_dims * last_dims, + i * strides[start_position - 1], ) parallelize[p_slice](first_dims) - + @staticmethod fn forward[ t1_shape: TensorShape, attributes: AttributeVector, ](inout res: Tensor[dtype], t1: Tensor[dtype]): - alias axes = attributes["axes"].value().to_shape() if attributes["axes"] else Self.default_axes(t1_shape) - alias starts = Self.reorder_positions[0](attributes["starts"].value().to_shape(), axes, t1_shape) - alias ends = Self.reorder_positions[1](attributes["ends"].value().to_shape(), axes, t1_shape) - alias steps = Self.reorder_positions[2](attributes["steps"].value().to_shape(), axes, t1_shape) if attributes["steps"] else Self.default_steps(t1_shape) + alias axes = attributes["axes"].value().to_shape() if attributes[ + "axes" + ] else Self.default_axes(t1_shape) + alias starts = Self.reorder_positions[0]( + attributes["starts"].value().to_shape(), axes, t1_shape + ) + alias ends = Self.reorder_positions[1]( + attributes["ends"].value().to_shape(), axes, t1_shape + ) + alias steps = Self.reorder_positions[2]( + attributes["steps"].value().to_shape(), axes, t1_shape + ) if attributes["steps"] else Self.default_steps(t1_shape) alias res_shape = Self.result_shape(t1_shape, attributes) - Self.slice_kernel[res_shape, t1_shape, steps, starts, ends, False](res, t1) + Self.slice_kernel[res_shape, t1_shape, steps, starts, ends, False]( + res, t1 + ) @staticmethod fn backward[ @@ -482,13 +592,378 @@ struct SLICE: t1_shape: TensorShape, attributes: AttributeVector = AttributeVector(), ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]: - alias axes = attributes["axes"].value().to_shape() if attributes["axes"] else Self.default_axes(t1_shape) - alias starts = Self.reorder_positions[0](attributes["starts"].value().to_shape(), axes, t1_shape) - alias ends = Self.reorder_positions[1](attributes["ends"].value().to_shape(), axes, t1_shape) - alias steps = Self.reorder_positions[2](attributes["steps"].value().to_shape(), axes, t1_shape) if attributes["steps"] else Self.default_steps(t1_shape) + alias axes = attributes["axes"].value().to_shape() if attributes[ + "axes" + ] else Self.default_axes(t1_shape) + alias starts = Self.reorder_positions[0]( + attributes["starts"].value().to_shape(), axes, t1_shape + ) + alias ends = Self.reorder_positions[1]( + attributes["ends"].value().to_shape(), axes, t1_shape + ) + alias steps = Self.reorder_positions[2]( + attributes["steps"].value().to_shape(), axes, t1_shape + ) if attributes["steps"] else Self.default_steps(t1_shape) var res_grad = Tensor[dtype](t1_shape) Self.slice_kernel[ug_shape, t1_shape, steps, starts, ends, True](res_grad, ug) - return res_grad ^ \ No newline at end of file + return res_grad ^ + + +struct INDEX: + @staticmethod + fn adjust_boundary(slice: Int, dim_size: Int) -> Int: + # Adjust negative indices & ensure they are within bounds. + var s = slice if slice >= 0 else dim_size + slice + return max(min(s, dim_size), 0) + + @staticmethod + fn to_indeces(shape: TensorShape, attrs: AttributeVector) -> List[List[Int]]: + var SLICE_LITERALS = List[StringLiteral]("dim_0s", "dim_1s", "dim_2s", "dim_3s", "dim_4s", "dim_5s", "dim_6s", "dim_7s") + var INDEX_LITERALS = List[StringLiteral]("dim_0i", "dim_1i", "dim_2i", "dim_3i", "dim_4i", "dim_5i", "dim_6i", "dim_7i") + + var indeces = List[List[Int]]() + for dim in range(shape.rank()): + var temp = List[Int]() + + # Option 1: Slice + if attrs[SLICE_LITERALS[dim]]: + var slice = attrs[SLICE_LITERALS[dim]].value().to_shape() + var step = slice[2] if slice.rank() == 3 else 1 + for i in range( + start=Self.adjust_boundary(slice[0], shape[dim]), + end=Self.adjust_boundary(slice[1], shape[dim]), + step=step + ): + temp.append(i) + + # Option 2: Indeces + elif attrs[INDEX_LITERALS[dim]]: + var indeces = attrs[INDEX_LITERALS[dim]].value().to_shape() + for i in range(indeces.rank()): + temp.append(indeces[i]) + + # All indeces + else: + for i in range(shape[dim]): + temp.append(i) + + indeces.append(temp) + + return indeces ^ + + @staticmethod + fn result_shape(shape: TensorShape, attrs: AttributeVector) -> TensorShape: + var indeces = Self.to_indeces(shape, attrs) + var new_shape = List[Int]() + for i in range(shape.rank()): + new_shape.append(len(indeces[i])) + return TensorShape(new_shape) + + @staticmethod + fn map_indeces[ + nelts: Int, + strides: TensorShape, + indeces: List[List[Int]], + ](idx: Int) -> SIMD[DType.int64, nelts]: + alias indeces_product = product(indeces) + + var temp = SIMD[DType.int64, nelts]() + for i in range(idx, idx + nelts): + var comb = indeces_product[i] + var flat_index = 0 + + for dim in range(len(comb)): + flat_index += comb[dim] * strides[dim] + + temp[i % nelts] = flat_index + + return temp + + @staticmethod + fn forward[ + t1_shape: TensorShape, + attributes: AttributeVector, + ](inout res: Tensor[dtype], t1: Tensor[dtype]): + alias indeces = Self.to_indeces(t1_shape, attributes) + alias strides = t1_shape.strides() + alias total_length = len(product(indeces)) + + @parameter + fn vec_index[nelts: Int](i: Int): + + res.store[nelts](i, + t1.data().gather(Self.map_indeces[nelts, strides, indeces](i)) + ) + + vectorize[vec_index, nelts](total_length) + + + @staticmethod + fn backward[ + ug_shape: TensorShape, + t1_shape: TensorShape, + attributes: AttributeVector = AttributeVector(), + ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]: + alias indeces = Self.to_indeces(t1_shape, attributes) + alias strides = t1_shape.strides() + alias total_length = len(product(indeces)) + + var res_grad = Tensor[dtype](t1_shape) + + @parameter + fn vec_index[nelts: Int](i: Int): + + var offset = Self.map_indeces[nelts, strides, indeces](i) + + # res_grad.data().scatter( + # offset, + # res_grad.data().gather(offset) + ug.load[nelts](i), + # ) + # BUG: Edge case in vectorization: + # When the offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1] + # It doesn't scatter to index 0 twice as it should be: res_grad[0] += 1 + 1 + + # Workaround + var u = ug.load[nelts](i) + for j in range(nelts): + res_grad[int(offset[j])] += u[j] + + vectorize[vec_index, nelts](total_length) + + return res_grad^ + + +struct UPSAMPLE: + @staticmethod + fn result_shape(t1_shape: TensorShape, attributes: AttributeVector) -> TensorShape: + var scales = attributes["scales"].value().to_shape() + var mode = attributes["mode"].value().to_string() + + var new_shape = List[Int]() + for i in range(0, t1_shape.rank()): + if i < 2: + new_shape.append(t1_shape[i]) + else: + new_shape.append(t1_shape[i] * scales[i - 2]) + + return TensorShape(new_shape) + + @staticmethod + fn recursive_iter[pos_shape: Int, shape: TensorShape, scales: TensorShape](inout res: Tensor[dtype], t1: Tensor[dtype], strides_res: StaticIntTuple[8], index_t1: Int, index_res: Int): + alias end_pos = shape.rank() - 1 + alias strides = shape.strides() + + @parameter + if pos_shape >= end_pos: + @parameter + fn v_iter[nelts: Int](i: Int): + var values = t1.load[nelts](index_t1 + i) + + var offset_res = index_res + i * scales[end_pos - 2] + for j in range(nelts * scales[pos_shape - 2]): + var temp = j // scales[pos_shape - 2] + + res[offset_res + j] = values[temp] + + vectorize[v_iter, nelts](shape[pos_shape]) + + return + else: + for i in range(shape[pos_shape] * scales[pos_shape - 2]): + var temp_i = i // scales[pos_shape - 2] + var temp_index_t1 = temp_i * strides[pos_shape] + index_t1 + var temp_index_res = i * strides_res[pos_shape] + index_res + + Self.recursive_iter[pos_shape + 1, shape, scales](res, t1, strides_res, temp_index_t1, temp_index_res) + + @staticmethod + fn forward[ + t1_shape: TensorShape, + attributes: AttributeVector, + ](inout res: Tensor[dtype], t1: Tensor[dtype]): + # Input is [N, C, D in, H in, W in], N is batch size and C is number of channels. Ranks 3-D, 4-D or 5-D tensors (only works on the spatial dimensions). + alias scales = attributes["scales"].value().to_shape() # Has to match spatial input dims (the last dimensions D, H and W) + alias mode = attributes["mode"].value().to_string() + # alias align_corners = attributes["align_corners"].value().to_bool() if attributes["align_corners"] else false + + @parameter + fn get_coordination_mode() -> String: + if mode == "linear" or mode == "bilinear": + return "half_pixel" + else: + return "asymmetric" + alias coordination_transforamtion = get_coordination_mode() + + alias strides = t1_shape.strides() + var strides_res = res.strides() + + var res_shape = res.shape() + + alias first_loop = t1_shape[0] * t1_shape[1] + + @always_inline + fn pos_asymmetric(pos: Int, scale: Int) -> Int: + return pos // scale + + @always_inline + fn pos_half_pixel(pos: Int, scale: Int) -> Float64: + return max(0.0, (pos + 0.5) / scale - 0.5) + + + @parameter + @always_inline + fn get_value_interpolate[size: Int]( + indeces_t1: StaticTuple[Float64, size] + ) -> SIMD[t1.dtype, 1]: + @parameter + if mode == "nearest": + var indeces_t1_sum = indeces_t1[0] + @parameter + for i in range(1, size): + indeces_t1_sum += indeces_t1[i] * strides[i + 1] + + return t1[int(indeces_t1_sum)] + elif mode == "linear": + var t1_pos_floor = floor(indeces_t1[1]) + var t1_pos_ceil = min(ceil(indeces_t1[1]), t1_shape[2] - 1) + + var v1 = t1[int(indeces_t1[0]) + int(t1_pos_floor)] + var v2 = t1[int(indeces_t1[0]) + int(t1_pos_ceil)] + + return v1 + (v2 - v1) * (indeces_t1[1] - t1_pos_floor) + elif mode == "bilinear": + var t1_pos_floor_y = floor(indeces_t1[1]) + var t1_pos_ceil_y = min(ceil(indeces_t1[1]), t1_shape[2] - 1) + + var t1_pos_floor_x = floor(indeces_t1[2]) + var t1_pos_ceil_x = min(ceil(indeces_t1[2]), t1_shape[3] - 1) + + var v1 = t1[int(indeces_t1[0]) + int(t1_pos_floor_y) * strides[2] + int(t1_pos_floor_x) * strides[3]] + var v2 = t1[int(indeces_t1[0]) + int(t1_pos_floor_y) * strides[2] + int(t1_pos_ceil_x) * strides[3]] + var v3 = t1[int(indeces_t1[0]) + int(t1_pos_ceil_y) * strides[2] + int(t1_pos_floor_x) * strides[3]] + var v4 = t1[int(indeces_t1[0]) + int(t1_pos_ceil_y) * strides[2] + int(t1_pos_ceil_x) * strides[3]] + + var wy = indeces_t1[1] - t1_pos_floor_y + var wx = indeces_t1[2] - t1_pos_floor_x + + var top_interp = v1 + (v2 - v1) * wx + var bottom_interp = v3 + (v4 - v3) * wx + + return top_interp + (bottom_interp - top_interp) * wy + else: + return 0 + + @always_inline + fn get_t1_position( + pos: Int, scale: Int, dim: Int + ) -> Float64: + @parameter + if coordination_transforamtion == "asymmetric": + return pos_asymmetric(pos, scale) + elif coordination_transforamtion == "half_pixel": + return pos_half_pixel(pos, scale) + else: + return 0 + + # it is possble to use gather, the only problem is to be able to create a simd arange (vectorized if it is with a for loop it is the same probably). (And from tests it seems to be slower, maybe because i do a lot of casts and because the arange of positions is not vectorized) + @parameter + fn p_iter(i: Int): + var offset_t1 = i * strides[1] + var offset_res = i * strides_res[1] + + @parameter + if t1_shape.rank() == 3: + var positions_t1 = StaticTuple[Float64, 2](0) + var positions_res = StaticIntTuple[2](0) + + positions_res[0] = offset_res + positions_t1[0] = offset_t1 + + @parameter + fn v_iter[nelts: Int](j: Int): + positions_res[1] = j + + var index_res = positions_res[0] + positions_res[1] + var values = res.load[nelts](index_res) + + for k in range(nelts): + positions_t1[1] = get_t1_position(j + k, scales[scales.rank() - 1], 0) + + values[k] = get_value_interpolate(positions_t1) + + res.store[nelts](index_res, values) + + + vectorize[v_iter, nelts](res_shape[res.rank() - 1]) + elif t1_shape.rank() == 4: + var positions_t1 = StaticTuple[Float64, 3](0) + var positions_res = StaticIntTuple[3](0) + + positions_res[0] = offset_res + positions_t1[0] = offset_t1 + + for j in range(res_shape[2]): + positions_res[1] = j * strides_res[2] + positions_t1[1] = get_t1_position(j, scales[0], 0) + + @parameter + fn v_iter_1[nelts: Int](k: Int): + positions_res[2] = k + + var index_res = positions_res[0] + positions_res[1] + positions_res[2] + var values = res.load[nelts](index_res) + + for l in range(nelts): + positions_t1[2] = get_t1_position(k + l, scales[scales.rank() - 1], 1) + + values[l] = get_value_interpolate(positions_t1) + + res.store[nelts](index_res, values) + + vectorize[v_iter_1, nelts](res_shape[res.rank() - 1]) + + elif t1_shape.rank() == 5: + var positions_t1 = StaticTuple[Float64, 4](0) + var positions_res = StaticIntTuple[4](0) + + positions_res[0] = offset_res + positions_t1[0] = offset_t1 + + for j in range(res.shape()[2]): + positions_res[1] = j * strides_res[2] + positions_t1[1] = get_t1_position(j, scales[0], 0) + for k in range(res.shape()[3]): + positions_res[2] = k * strides_res[3] + positions_t1[2] = get_t1_position(k, scales[1], 1) + + @parameter + fn v_iter_2[nelts: Int](l: Int): + positions_res[3] = l + + var index_res = positions_res[0] + positions_res[1] + positions_res[2] + positions_res[3] + var values = res.load[nelts](index_res) + + for m in range(nelts): + positions_t1[3] = get_t1_position(l + m, scales[scales.rank() - 1], 2) + + values[m] = get_value_interpolate(positions_t1) + + res.store[nelts](index_res, values) + + vectorize[v_iter_2, nelts](res_shape[res.rank() - 1]) + else: + # Error + pass + + parallelize[p_iter](first_loop) + + @staticmethod + fn backward[ + ug_shape: TensorShape, + t1_shape: TensorShape, + attributes: AttributeVector = AttributeVector(), + ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]: + return t1 diff --git a/basalt/autograd/ops/ops.mojo b/basalt/autograd/ops/ops.mojo index 71982706..b870f786 100644 --- a/basalt/autograd/ops/ops.mojo +++ b/basalt/autograd/ops/ops.mojo @@ -15,7 +15,7 @@ from .basics import ( TRANSPOSE, FMA, ) -from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE +from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE, INDEX, UPSAMPLE, LEAKYRELU from .dynamics import CONCAT, SPLIT from .conv import CONV2D from .pool import MAXPOOL2D @@ -61,6 +61,9 @@ struct OP(Stringable): alias CONCAT = OP(23, "CONCAT", dynamic=True) alias SPLIT = OP(24, "SPLIT", dynamic=True) alias SLICE = OP(25, "SLICE") + alias INDEX = OP(26, "INDEX") + alias UPSAMPLE = OP(27, "UPSAMPLE") + alias LEAKYRELU = OP(28, "LEAKYRELU") var id: UInt8 var name: Bytes[16] @@ -87,10 +90,16 @@ fn static_result_shape( if len(operands) == 1: return static_result_shape(op, operands[0].shape, attributes) elif len(operands) == 2: - return static_result_shape(op, operands[0].shape, operands[1].shape, attributes) + return static_result_shape( + op, operands[0].shape, operands[1].shape, attributes + ) elif len(operands) == 3: return static_result_shape( - op, operands[0].shape, operands[1].shape, operands[2].shape, attributes + op, + operands[0].shape, + operands[1].shape, + operands[2].shape, + attributes, ) else: print("Error: Invalid number of operands") @@ -121,6 +130,8 @@ fn static_result_shape( return SIGMOID.result_shape(t1_shape) elif op == OP.RELU: return RELU.result_shape(t1_shape) + elif op == OP.LEAKYRELU: + return LEAKYRELU.result_shape(t1_shape) elif op == OP.TANH: return TANH.result_shape(t1_shape) elif op == OP.TRANSPOSE: @@ -135,6 +146,10 @@ fn static_result_shape( return UNSQUEEZE.result_shape(t1_shape, attributes) elif op == OP.SLICE: return SLICE.result_shape(t1_shape, attributes) + elif op == OP.INDEX: + return INDEX.result_shape(t1_shape, attributes) + elif op == OP.UPSAMPLE: + return UPSAMPLE.result_shape(t1_shape, attributes) else: print("[ERROR] Operator not found.") return TensorShape(-1) @@ -235,6 +250,8 @@ fn forward_op[ SIGMOID.forward[t1_shape](res, t1) elif op == OP.RELU: RELU.forward[t1_shape](res, t1) + elif op == OP.LEAKYRELU: + LEAKYRELU.forward[t1_shape, attributes](res, t1) elif op == OP.TANH: TANH.forward[t1_shape](res, t1) elif op == OP.TRANSPOSE: @@ -249,12 +266,19 @@ fn forward_op[ UNSQUEEZE.forward[t1_shape, attributes](res, t1) elif op == OP.SLICE: SLICE.forward[t1_shape, attributes](res, t1) + elif op == OP.INDEX: + INDEX.forward[t1_shape, attributes](res, t1) + elif op == OP.UPSAMPLE: + UPSAMPLE.forward[t1_shape, attributes](res, t1) else: print("[ERROR] Operator not found.") fn forward_op[ - op: OP, t1_shape: TensorShape, t2_shape: TensorShape, attributes: AttributeVector + op: OP, + t1_shape: TensorShape, + t2_shape: TensorShape, + attributes: AttributeVector, ](inout res: Tensor[dtype], t1: Tensor[dtype], t2: Tensor[dtype]): """ Forward pass for binary operators. @@ -283,14 +307,21 @@ fn forward_op[ t2_shape: TensorShape, t3_shape: TensorShape, attributes: AttributeVector, -](inout res: Tensor[dtype], t1: Tensor[dtype], t2: Tensor[dtype], t3: Tensor[dtype]): +]( + inout res: Tensor[dtype], + t1: Tensor[dtype], + t2: Tensor[dtype], + t3: Tensor[dtype], +): """ Forward pass for ternary operators. """ @parameter if op == OP.CONV2D: - CONV2D.forward[t1_shape, t2_shape, t3_shape, attributes](res, t1, t2, t3) + CONV2D.forward[t1_shape, t2_shape, t3_shape, attributes]( + res, t1, t2, t3 + ) elif op == OP.FMA: FMA.forward[t1_shape, t2_shape, t3_shape](res, t1, t2, t3) else: @@ -303,7 +334,7 @@ fn forward_op[ ]( inputs: List[Symbol], outputs: List[Symbol], - parameters: Parameters, + inout parameters: Parameters, ): """ Forward pass for dynamic operators. @@ -347,6 +378,8 @@ fn backward_op[ res_grad = SIGMOID.backward[ug_shape, t1_shape](ug, t1) elif op == OP.RELU: res_grad = RELU.backward[ug_shape, t1_shape](ug, t1) + elif op == OP.LEAKYRELU: + res_grad = LEAKYRELU.backward[ug_shape, t1_shape, attributes](ug, t1) elif op == OP.TANH: res_grad = TANH.backward[ug_shape, t1_shape](ug, t1) elif op == OP.TRANSPOSE: @@ -361,6 +394,8 @@ fn backward_op[ res_grad = UNSQUEEZE.backward[ug_shape, t1_shape](ug, t1) elif op == OP.SLICE: res_grad = SLICE.backward[ug_shape, t1_shape, attributes](ug, t1) + elif op == OP.INDEX: + res_grad = INDEX.backward[ug_shape, t1_shape, attributes](ug, t1) else: print("[ERROR] Operator not found.") res_grad = Tensor[dtype](-1) @@ -375,7 +410,12 @@ fn backward_op[ t1_shape: TensorShape, t2_shape: TensorShape, attributes: AttributeVector, -](ug: Tensor[dtype], t1: Tensor[dtype], t2: Tensor[dtype], inout grad: Tensor[dtype]): +]( + ug: Tensor[dtype], + t1: Tensor[dtype], + t2: Tensor[dtype], + inout grad: Tensor[dtype], +): """ Backward pass for binary operators. """ @@ -383,17 +423,29 @@ fn backward_op[ @parameter if op == OP.ADD: - res_grad = ADD.backward[tensor_id, ug_shape, t1_shape, t2_shape](ug, t1, t2) + res_grad = ADD.backward[tensor_id, ug_shape, t1_shape, t2_shape]( + ug, t1, t2 + ) elif op == OP.SUB: - res_grad = SUB.backward[tensor_id, ug_shape, t1_shape, t2_shape](ug, t1, t2) + res_grad = SUB.backward[tensor_id, ug_shape, t1_shape, t2_shape]( + ug, t1, t2 + ) elif op == OP.MUL: - res_grad = MUL.backward[tensor_id, ug_shape, t1_shape, t2_shape](ug, t1, t2) + res_grad = MUL.backward[tensor_id, ug_shape, t1_shape, t2_shape]( + ug, t1, t2 + ) elif op == OP.DIV: - res_grad = DIV.backward[tensor_id, ug_shape, t1_shape, t2_shape](ug, t1, t2) + res_grad = DIV.backward[tensor_id, ug_shape, t1_shape, t2_shape]( + ug, t1, t2 + ) elif op == OP.POW: - res_grad = POW.backward[tensor_id, ug_shape, t1_shape, t2_shape](ug, t1, t2) + res_grad = POW.backward[tensor_id, ug_shape, t1_shape, t2_shape]( + ug, t1, t2 + ) elif op == OP.DOT: - res_grad = DOT.backward[tensor_id, ug_shape, t1_shape, t2_shape](ug, t1, t2) + res_grad = DOT.backward[tensor_id, ug_shape, t1_shape, t2_shape]( + ug, t1, t2 + ) else: print("[ERROR] Operator not found.") res_grad = Tensor[dtype](-1, -1) @@ -437,9 +489,9 @@ fn backward_op[ tensor_id, ug_shape, t1_shape, t2_shape, t3_shape, attributes ](ug, t1, t2, t3) elif op == OP.FMA: - res_grad = FMA.backward[tensor_id, ug_shape, t1_shape, t2_shape, t3_shape]( - ug, t1, t2, t3 - ) + res_grad = FMA.backward[ + tensor_id, ug_shape, t1_shape, t2_shape, t3_shape + ](ug, t1, t2, t3) else: print("[ERROR] Operator not found.") res_grad = Tensor[dtype](-1, -1) @@ -455,7 +507,7 @@ fn backward_op[ inputs: List[Symbol], outputs: List[Symbol], inout grad: Tensor[dtype], - parameters: Parameters, + inout parameters: Parameters, ): """ Backward pass for dynamic operators. @@ -463,9 +515,13 @@ fn backward_op[ var res_grad: Tensor[dtype] if op == OP.CONCAT: - res_grad = CONCAT.backward[input_id, attributes](inputs, outputs, parameters) + res_grad = CONCAT.backward[input_id, attributes]( + inputs, outputs, parameters + ) elif op == OP.SPLIT: - res_grad = SPLIT.backward[input_id, attributes](inputs, outputs, parameters) + res_grad = SPLIT.backward[input_id, attributes]( + inputs, outputs, parameters + ) else: print("[ERROR] Operator not found.") res_grad = Tensor[dtype](-1, -1) diff --git a/basalt/autograd/ops/pool.mojo b/basalt/autograd/ops/pool.mojo index 5e927407..3149cc10 100644 --- a/basalt/autograd/ops/pool.mojo +++ b/basalt/autograd/ops/pool.mojo @@ -1,4 +1,4 @@ -from math.limit import neginf +from utils.numerics import min_or_neg_inf from basalt import Tensor, TensorShape from basalt.autograd.attributes import AttributeVector @@ -48,7 +48,7 @@ struct MAXPOOL2D: for in_ch in range(input_shape[1]): for x in range(output_shape[2]): for y in range(output_shape[3]): - var max_val: Scalar[dtype] = neginf[dtype]() + var max_val: Scalar[dtype] = min_or_neg_inf[dtype]() var ix_base = x * stride[0] - padding[0] var iy_base = y * stride[1] - padding[1] for kx in range(kernel_size[0]): @@ -107,7 +107,7 @@ struct MAXPOOL2D: for in_ch in range(input_shape[1]): for x in range(ug_shape[2]): for y in range(ug_shape[3]): - var max_val: Scalar[dtype] = neginf[dtype]() + var max_val: Scalar[dtype] = min_or_neg_inf[dtype]() var max_idx: Int = -1 var ix_base = x * stride[0] - padding[0] var iy_base = y * stride[1] - padding[1] diff --git a/basalt/autograd/params.mojo b/basalt/autograd/params.mojo index 5d828489..37d682a8 100644 --- a/basalt/autograd/params.mojo +++ b/basalt/autograd/params.mojo @@ -19,10 +19,8 @@ struct Param(CollectionElement, Stringable): self.data = data self.initializer = None - fn __init__(inout self, a: Scalar[dtype]): - var data = List[Scalar[dtype]]() - data.append(a) - self.data = data + fn __init__(inout self, data: Scalar[dtype]): + self.data = List[Scalar[dtype]](data) self.initializer = None fn __init__(inout self, initializer: String, *args: Scalar[dtype]): diff --git a/basalt/nn/__init__.mojo b/basalt/nn/__init__.mojo index 99b30a31..d85ab275 100644 --- a/basalt/nn/__init__.mojo +++ b/basalt/nn/__init__.mojo @@ -4,6 +4,14 @@ from .model import Model from .layers.linear import Linear from .layers.conv import Conv2d from .layers.pool import MaxPool2d +from .layers.upsample import Upsample from .loss import MSELoss, CrossEntropyLoss -from .activations import Softmax, LogSoftmax, ReLU, Sigmoid, Tanh +from .activations import ( + Softmax, + LogSoftmax, + ReLU, + LeakyReLU, + Sigmoid, + Tanh, +) diff --git a/basalt/nn/activations.mojo b/basalt/nn/activations.mojo index 2264a541..9a83a0fd 100644 --- a/basalt/nn/activations.mojo +++ b/basalt/nn/activations.mojo @@ -2,13 +2,22 @@ from basalt import Tensor, TensorShape from basalt import Graph, Symbol, OP from basalt.autograd.attributes import Attribute, AttributeVector -# '''Activation functions.''' - +# '''Activation functions.''' fn ReLU(inout g: Graph, input: Symbol) -> Symbol: return g.op(OP.RELU, input) +fn LeakyReLU( + inout g: Graph, input: Symbol, negative_slope: Scalar[dtype] +) -> Symbol: + return g.op( + OP.LEAKYRELU, + input, + attributes=AttributeVector(Attribute("negative_slope", negative_slope)), + ) + + fn Sigmoid(inout g: Graph, input: Symbol) -> Symbol: return g.op(OP.SIGMOID, input) diff --git a/basalt/nn/layers/upsample.mojo b/basalt/nn/layers/upsample.mojo new file mode 100644 index 00000000..c70de00e --- /dev/null +++ b/basalt/nn/layers/upsample.mojo @@ -0,0 +1,117 @@ +from basalt import dtype +from basalt import Graph, Symbol, OP +from basalt import Tensor, TensorShape +from basalt.autograd.attributes import AttributeVector, Attribute +from basalt.utils.itertools import product + + +fn _scale_indeces(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> List[Scalar[dtype]]: + var M = int(scale * N) + var indeces = List[Scalar[dtype]]() + if align_corners: + for i in range(M): + indeces.append(i * ((N - 1) / (M - 1))) + else: + var step = 1 / scale + var start = ((M - 1) * step - N + 1) / 2 + for i in range(M): + indeces.append(i * step - start) + + return indeces ^ + + +fn nearest_coeffs(N: Int, scale: Scalar[dtype], dim: Int, ndims: Int) -> List[Int]: + + @parameter + fn round_to_index(number: Scalar[dtype]) -> Int: + return int(number + 0.5) if number > 0 else int(number - 0.5) + + var indeces = List[Int]() + var scaled = _scale_indeces(N, scale, True, dim, ndims) + for i in range(len(scaled)): + indeces.append(round_to_index(scaled[i])) + return indeces ^ + + +fn linear_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Tuple[List[Int], List[Int]]: + # TODO + return (List[Int](), List[Int]()) + + +fn cubic_coeffs(N: Int, scale: Scalar[dtype], align_corners: Bool, dim: Int, ndims: Int) -> Tuple[List[Int], List[Int]]: + # TODO + return (List[Int](), List[Int]()) + + +fn interpolate_nd[ + indices_fn: fn (Int, Scalar[dtype], Bool, Int, Int) -> Tuple[List[Int], List[Int]], +](inout g: Graph, input: Symbol, scale_factors: List[Scalar[dtype]], align_corners: Bool) -> Symbol: + + var spatial_dims = input.shape.rank() - 2 + + var indeces_weights = List[Tuple[List[Int], List[Int]]]() + for i in range(spatial_dims): + indeces_weights.append( + indices_fn( + input.shape[i + 2], + scale_factors[i], + align_corners, + i, + spatial_dims, + ) + ) + + # TODO: interpolation logic + # for idx_weight in product(indeces_weights): + # ... + + return Symbol(-1, dtype, TensorShape(), False) + + +fn Upsample( + inout g: Graph, + input: Symbol, + mode: StringLiteral, + scale_factors: List[Scalar[dtype]], + align_corners: Bool = False, +) -> Symbol: + + # Assumption: A scale needs to be provided for each spatial dimension. + # input shape (B, C, *N) with batch and channel considered as non-spatial dimensions. + # input.shape.rank() - 2 == len(scale_factor) + var spatial_dims = input.shape.rank() - 2 + + var res: Symbol + var attributes = AttributeVector() + var INDEX_LITERALS = List[StringLiteral]("dim_2i", "dim_3i", "dim_4i") + + if mode == "nearest": + # Nearest neighbor interpolation --> input[:, :, *indeces] + for i in range(spatial_dims): + attributes.append( + Attribute( + INDEX_LITERALS[i], + nearest_coeffs(input.shape[i + 2], scale_factors[i], i, spatial_dims) + ) + ) + + res = g.op(OP.INDEX, input, attributes=attributes) + + # elif mode == "linear": + # res = interpolate_nd[linear_coeffs](g, + # input, + # scale_factor, + # align_corners + # ) + + # elif mode == "cubic": + # res = interpolate_nd[cubic_coeffs](g, + # input, + # scale_factor, + # align_corners + # ) + else: + res = input + + return res + diff --git a/basalt/nn/model.mojo b/basalt/nn/model.mojo index ed80c7ec..a8993cf3 100644 --- a/basalt/nn/model.mojo +++ b/basalt/nn/model.mojo @@ -80,7 +80,7 @@ struct Model[ # TODO: remove when ability to concatenate graphs (modules) # Removes the need for splitting in forward and inference mode - fn forward(inout self, *t_inputs: Tensor[dtype]) -> Tensor[dtype]: + fn forward(inout self, *t_inputs: Tensor[dtype]) -> ref[__lifetime_of(self)] Tensor[dtype]: # NOTE: Important detail here is that the order of the inputs must be the same as the order the inputs were defined in the graph. # Example: If you were te define the y_true before the x when creating the graph # @@ -117,7 +117,7 @@ struct Model[ # 2. Loop over all nodes and execute forward operations @parameter - fn fw_unroll[i: Int](): + for i in range(num_nodes): alias op = g.nodes[i].operator alias attrs = g.nodes[i].attributes @@ -169,8 +169,6 @@ struct Model[ if DEBUG == 1: self.perf_metrics.end_forward_pass(i) - unroll[fw_unroll, num_nodes]() - fn backward(inout self, *upper_grads: Tensor[dtype]): """ Main entrypoint of backward pass. @@ -191,7 +189,7 @@ struct Model[ # 2. Loop over all nodes in reverse order and execute backward operations @parameter - fn bw_unroll[i: Int](): + for i in range(g.nodes.size): alias reverse_i = g.nodes.size - i - 1 alias op = g.nodes[reverse_i].operator alias attrs = g.nodes[reverse_i].attributes @@ -206,7 +204,7 @@ struct Model[ if op.dynamic: @parameter - fn unroll_dynamic[j: Int](): + for j in range(num_operands): @parameter if g.nodes[reverse_i].inputs[j].trainable: backward_op[j, op, attrs]( @@ -215,9 +213,6 @@ struct Model[ self.parameters.grads[g.nodes[reverse_i].inputs[j]], self.parameters, ) - - unroll[unroll_dynamic, num_operands]() - else: # Statically known shapes and number of operands alias out = g.nodes[reverse_i].outputs[0] # or upper_grad symbol @@ -302,8 +297,6 @@ struct Model[ if DEBUG == 1: self.perf_metrics.end_backward_pass(i) - unroll[bw_unroll, g.nodes.size]() - fn allocate_tensor_memory(inout self): for i in range(len(g.inputs)): self.parameters.tensors.append( @@ -375,7 +368,7 @@ struct Model[ except e: print("Error loading model data:", e) - fn export_model(self, model_path: String): + fn export_model(inout self, model_path: String): var path = Path(model_path) print("Exporting model to:", path) diff --git a/basalt/nn/optim.mojo b/basalt/nn/optim.mojo index 1ba90f2b..db6210a0 100644 --- a/basalt/nn/optim.mojo +++ b/basalt/nn/optim.mojo @@ -1,9 +1,10 @@ -from math import add, mul, div, sqrt, sub +from math import sqrt from algorithm import vectorize, parallelize from .model import Parameters from basalt import Graph, Tensor, TensorShape from basalt.utils.collection import Collection +from basalt.utils.math_util import add, sub, mul, div fn get_trainable_parameters(g: Graph) -> List[Symbol]: @@ -20,13 +21,14 @@ fn get_trainable_parameters(g: Graph) -> List[Symbol]: return trainable_parameters ^ +@value struct Adam[ + lifetime: MutableLifetime, # using mutability and anylifetime, seems to give problem for now because the the reference can't now for sure if the lifetime is mutable or not + //, g: Graph, - mutability: __mlir_type.i1, - lifetime: AnyLifetime[mutability].type, trainable_parameters: List[Symbol] = get_trainable_parameters(g), ]: - var parameters: Reference[Parameters, mutability, lifetime] + var parameters: Reference[Parameters, True, lifetime] var lr: Scalar[dtype] var beta1: Scalar[dtype] @@ -39,7 +41,7 @@ struct Adam[ fn __init__( inout self, - parameters: Reference[Parameters, mutability, lifetime], + parameters: Reference[Parameters, True, lifetime], lr: Scalar[dtype] = 0.001, beta1: Scalar[dtype] = 0.9, beta2: Scalar[dtype] = 0.999, diff --git a/basalt/nn/tensor.mojo b/basalt/nn/tensor.mojo index b3fa5513..63fb02ad 100644 --- a/basalt/nn/tensor.mojo +++ b/basalt/nn/tensor.mojo @@ -1,4 +1,3 @@ -from math import min from testing import assert_true from algorithm import vectorize @@ -14,40 +13,34 @@ struct TensorShape(Stringable): var _rank: Int var _shape: StaticIntTuple[MAX_RANK] - @always_inline("nodebug") fn __init__(inout self, *shape: Int): self._rank = len(shape) self._shape = StaticIntTuple[MAX_RANK]() for i in range(min(self._rank, MAX_RANK)): self._shape[i] = shape[i] - @always_inline("nodebug") fn __init__(inout self, shapes: VariadicList[Int]): self._rank = len(shapes) self._shape = StaticIntTuple[MAX_RANK]() for i in range(min(self._rank, MAX_RANK)): self._shape[i] = shapes[i] - @always_inline("nodebug") fn __init__(inout self, shape: List[Int]): self._rank = len(shape) self._shape = StaticIntTuple[MAX_RANK]() for i in range(min(self._rank, MAX_RANK)): self._shape[i] = shape[i] - @always_inline("nodebug") fn __init__[num: Int](inout self, shape: StaticIntTuple[num]): self._rank = num self._shape = StaticIntTuple[MAX_RANK]() for i in range(min(self._rank, MAX_RANK)): self._shape[i] = shape[i] - @always_inline("nodebug") fn __init__(inout self, rank: Int, shape: StaticIntTuple[MAX_RANK]): self._rank = rank self._shape = shape - @always_inline("nodebug") fn __init__(inout self, owned shape: _TensorShape): self._rank = shape.rank() self._shape = StaticIntTuple[MAX_RANK]() @@ -117,19 +110,16 @@ struct Tensor[dtype: DType](Stringable, Movable, CollectionElement): var _data: DTypePointer[dtype] var _shape: TensorShape - @always_inline("nodebug") fn __init__(inout self, *dims: Int): self._shape = TensorShape(dims) self._data = DTypePointer[dtype].alloc(self._shape.num_elements()) memset_zero(self._data, self._shape.num_elements()) - @always_inline("nodebug") fn __init__(inout self, owned shape: TensorShape): self._data = DTypePointer[dtype].alloc(shape.num_elements()) memset_zero(self._data, shape.num_elements()) self._shape = shape - @always_inline("nodebug") fn __init__( inout self, owned data: DTypePointer[dtype], owned shape: TensorShape ): @@ -140,20 +130,17 @@ struct Tensor[dtype: DType](Stringable, Movable, CollectionElement): memcpy(self._data, data, self._shape.num_elements()) _ = data - @always_inline("nodebug") fn __init__(inout self, owned tensor: _Tensor[dtype]): self._data = DTypePointer[dtype].alloc(tensor.num_elements()) self._shape = tensor.shape() - memcpy(self._data, tensor.data(), self._shape.num_elements()) + memcpy(self._data, tensor.unsafe_ptr(), self._shape.num_elements()) _ = tensor - @always_inline("nodebug") fn __moveinit__(inout self, owned other: Tensor[dtype]): self._data = other._data self._shape = other._shape - @always_inline("nodebug") fn __copyinit__(inout self, other: Tensor[dtype]): # print("[WARNING] Copying tensor") self._data = DTypePointer[dtype].alloc(other._shape.num_elements()) diff --git a/basalt/utils/bytes.mojo b/basalt/utils/bytes.mojo index 498851b0..8125a307 100644 --- a/basalt/utils/bytes.mojo +++ b/basalt/utils/bytes.mojo @@ -1,5 +1,6 @@ from math import nan -from math.limit import inf +from utils.numerics import inf +from utils.static_tuple import StaticTuple alias ScalarBytes = DType.uint64.sizeof() @@ -12,22 +13,18 @@ struct Bytes[capacity: Int](Stringable, CollectionElement, EqualityComparable): var data: StaticTuple[UInt8, capacity] - @always_inline("nodebug") fn __init__(inout self): - var data = StaticTuple[UInt8, capacity]() + var data = StaticTuple[UInt8, capacity](0) - @unroll for i in range(capacity): data[i] = 0 self.data = data - @always_inline("nodebug") fn __init__(inout self, s: String): - var data = StaticTuple[UInt8, capacity]() + var data = StaticTuple[UInt8, capacity](0) var length = len(s) - @unroll for i in range(capacity): data[i] = ord(s[i]) if i < length else 0 @@ -47,7 +44,6 @@ struct Bytes[capacity: Int](Stringable, CollectionElement, EqualityComparable): @always_inline("nodebug") fn __eq__(self, other: Self) -> Bool: - @unroll for i in range(capacity): if self[i] != other[i]: return False @@ -55,7 +51,6 @@ struct Bytes[capacity: Int](Stringable, CollectionElement, EqualityComparable): @always_inline("nodebug") fn __ne__(self, other: Self) -> Bool: - @unroll for i in range(capacity): if self[i] != other[i]: return True @@ -65,7 +60,6 @@ struct Bytes[capacity: Int](Stringable, CollectionElement, EqualityComparable): fn __str__(self) -> String: var result: String = "" - @unroll for i in range(capacity): var val = self[i] if val != 0: @@ -82,7 +76,6 @@ fn scalar_to_bytes[ var bits = bitcast[DType.uint64](value.cast[expand_type[dtype]()]()) var data = Bytes[Size]() - @unroll for i in range(ScalarBytes): data[i] = (bits >> (i << 3)).cast[DType.uint8]() @@ -94,7 +87,6 @@ fn bytes_to_scalar[dtype: DType](data: Bytes) -> Scalar[dtype]: var bits: UInt64 = 0 - @unroll for i in range(ScalarBytes): bits |= data[i].cast[DType.uint64]() << (i << 3) diff --git a/basalt/utils/collection.mojo b/basalt/utils/collection.mojo index 0a8aea91..1528844d 100644 --- a/basalt/utils/collection.mojo +++ b/basalt/utils/collection.mojo @@ -1,7 +1,6 @@ -from math import max -from memory.unsafe_pointer import UnsafePointer, move_from_pointee, initialize_pointee_copy, initialize_pointee_move, destroy_pointee +from memory.unsafe_pointer import UnsafePointer, initialize_pointee_move, destroy_pointee -from basalt import Tensor, TensorShape, Symbol +from basalt import Tensor, Symbol struct Collection(CollectionElement, Sized): @@ -108,26 +107,42 @@ struct Collection(CollectionElement, Sized): fn get_index(self, symbol_name: UInt32) -> Int: """ Returns the index of the tensor with the given symbol name. - """ - for i in range(self.size): - if self.symbols[i] == symbol_name: - return i + """ + alias factor = 8 + # 2 -> 5.32s MNIST + # 4 -> 4.95s MNIST + # 8 -> 4.85s MNIST + # 16 -> 5.19s MNIST + # NOTE: This ideally should just be a hashmap + + for i in range(0, self.size, factor): + var elems = self.symbols.load[width=factor](i) == symbol_name + + for j in range(factor): + if elems[j]: + return i + j + + var split = divmod(self.size, factor) + + for i in range(split[1]): + var index = split[0] + i + + if self.symbols[index] == symbol_name: + return index + return -1 - @always_inline("nodebug") - fn __refitem__[ - mutability: __mlir_type.i1, - lifetime: AnyLifetime[mutability].type, - ]( - self: Reference[Self, mutability, lifetime]._mlir_type, + fn __getitem__( + inout self, symbol: Symbol, - ) -> Reference[Tensor[dtype], mutability, lifetime]: + ) -> ref[__lifetime_of(self)] Tensor[dtype]: """ Returns a reference to the tensor with the given symbol. """ - var index = Reference(self)[].get_index(symbol.name) + var index = self.get_index(symbol.name) + - return (Reference(self)[].data + index)[] + return (self.data + index)[0] @always_inline("nodebug") fn clear(inout self): diff --git a/basalt/utils/datasets.mojo b/basalt/utils/datasets.mojo index cb019ae9..ff5b3562 100644 --- a/basalt/utils/datasets.mojo +++ b/basalt/utils/datasets.mojo @@ -1,11 +1,15 @@ from algorithm import vectorize -from math import div from basalt import dtype from basalt import Tensor, TensorShape from basalt.utils.tensorutils import elwise_op, tmean, tstd +@always_inline +fn div[dtype: DType, simd_width: Int](a: SIMD[dtype, simd_width], b: Scalar[dtype]) -> SIMD[dtype, simd_width]: + return a / b + + struct BostonHousing: alias n_inputs = 13 diff --git a/basalt/utils/itertools.mojo b/basalt/utils/itertools.mojo new file mode 100644 index 00000000..2b7d3abf --- /dev/null +++ b/basalt/utils/itertools.mojo @@ -0,0 +1,49 @@ + +@value +struct _ProductIterator(Sized): + var lists: List[List[Int]] + var _current: Int + var _iters: Int + + @always_inline("nodebug") + fn __init__(inout self, lists: List[List[Int]]): + self.lists = lists + self._current = 0 + + self._iters = 1 + for lst in self.lists: + self._iters *= len(lst[]) + + @always_inline("nodebug") + fn __len__(self) -> Int: + return self._iters + + @always_inline("nodebug") + fn __iter__(self) -> Self: + return self + + @always_inline("nodebug") + fn __next__(inout self) -> List[Int]: + self._current += 1 + self._iters -= 1 + return self._get_combination(self._current - 1) + + @always_inline("nodebug") + fn _get_combination(self, current: Int) -> List[Int]: + var combination = List[Int]() + var count = current + for i in reversed(range(len(self.lists))): + var index = count % len(self.lists[i]) + combination.append(self.lists[i][index]) + count //= len(self.lists[i]) + combination.reverse() + return combination ^ + + @always_inline("nodebug") + fn __getitem__(self, index: Int) -> List[Int]: + return self._get_combination(index) + + +@always_inline("nodebug") +fn product(lists: List[List[Int]]) -> _ProductIterator: + return _ProductIterator(lists) \ No newline at end of file diff --git a/basalt/utils/math_util.mojo b/basalt/utils/math_util.mojo new file mode 100644 index 00000000..faeab908 --- /dev/null +++ b/basalt/utils/math_util.mojo @@ -0,0 +1,41 @@ +@always_inline +fn add[ + dtype: DType, simd_width: Int +](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[ + dtype, simd_width +]: + return a + b + + +@always_inline +fn sub[ + dtype: DType, simd_width: Int +](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[ + dtype, simd_width +]: + return a - b + + +@always_inline +fn mul[ + dtype: DType, simd_width: Int +](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[ + dtype, simd_width +]: + return a * b + + +@always_inline +fn div[ + dtype: DType, simd_width: Int +](a: SIMD[dtype, simd_width], b: SIMD[dtype, simd_width]) -> SIMD[ + dtype, simd_width +]: + return a / b + + +@always_inline +fn round_simd[ + dtype: DType, simd_width: Int +](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]: + return round(x) diff --git a/basalt/utils/onnx_utils.mojo b/basalt/utils/onnx_utils.mojo index 9eeda441..fde4d909 100644 --- a/basalt/utils/onnx_utils.mojo +++ b/basalt/utils/onnx_utils.mojo @@ -6,33 +6,13 @@ from basalt.nn.model import Parameters from basalt.nn.tensor import Tensor, TensorShape from basalt.autograd.attributes import Attribute, AttributeType from basalt.autograd.ops import OP +from basalt.autograd.graph import Node + +from .tensor_creation_utils import to_numpy, copy_np_data # NOTE: Maybe we could create our own model representation and from there convert to onnx or others (well we already have it in reallity) # NOTE: Torch doesn't import onnx, need onnx2torch and it doesn't support operators like reshape? -fn to_numpy(tensor: Tensor) raises -> PythonObject: - var np = Python.import_module("numpy") - - np.set_printoptions(4) - var rank = tensor.rank() - var pyarray: PythonObject = np.array([0]) - - if rank == 1: - pyarray = np.empty((tensor.dim(0))) - elif rank == 2: - pyarray = np.empty((tensor.dim(0), tensor.dim(1))) - elif rank == 3: - pyarray = np.empty((tensor.dim(0), tensor.dim(1), tensor.dim(2))) - elif rank == 4: - pyarray = np.empty((tensor.dim(0), tensor.dim(1), tensor.dim(2), tensor.dim(3))) - else: - print("Error: rank not supported: ", rank) - - for i in range(tensor.num_elements()): - pyarray.itemset((i), tensor[i]) - - return pyarray - fn make_onnx_attribute(op: OP, attr: Attribute) raises -> PythonObject: var onnx = Python.import_module("onnx") @@ -68,9 +48,7 @@ fn make_onnx_attribute(op: OP, attr: Attribute) raises -> PythonObject: else: raise Error("Unsupported attribute name for operator " + str(op)) - if (op == OP.CONV2D and attr_name) == "pads" or ( - op == OP.MAXPOOL2D and attr_name - ) == "pads": + if (op == OP.CONV2D or op == OP.MAXPOOL2D) and attr_name == "pads": # Special case for pads in conv and maxpool, onnx wants pads to be [x1_begin, x2_begin…x1_end, x2_end,…], attr_value.append(attr_value[0]) attr_value.append(attr_value[1]) @@ -185,21 +163,94 @@ fn load_onnx_model( "Shape mismatch for tensor " + str(i) + ". Expected shape: " - + model_tensor_shape + + str(model_tensor_shape) + ", got shape: " - + data_shape + + str(data_shape) ) - var data = data_np.flatten() - - # It would be better to use memcpy here - for j in range(len(data)): - model_parameters.tensors[g.params.symbols[i]][j] = data[j].to_float64() + copy_np_data(model_parameters.tensors[g.params.symbols[i]], data_np) else: raise Error("Unsupported data type") -fn export_onnx_model(model_path: Path, model_parameters: Parameters, g: Graph) raises: +fn create_attributes_and_constant_inputs(node: Node, node_number: Int) raises -> (List[PythonObject], List[PythonObject]): + var onnx = Python.import_module("onnx") + var np = Python.import_module("numpy") + + var attributes = List[PythonObject]() + var inputs = List[PythonObject]() + + for i in range(len(node.attributes)): + var attr = node.attributes[i] + + @parameter + fn to_np_array(attr: Attribute) raises -> PythonObject: + if not attr.type == AttributeType.INTS: + raise Error("Attribute is not a shape") + + var values_np: PythonObject + if attr.type == AttributeType.INTS: + var shape = attr.to_shape() + values_np = PythonObject([]) + for j in range(shape.rank()): + values_np.append(shape[j]) + elif attr.type == AttributeType.FLOAT: + values_np = attr.to_scalar[DType.float64]() + elif attr.type == AttributeType.INT: + values_np = attr.to_int() + else: + raise Error("Unsupported attribute type") + + var np_array = np.array(values_np, dtype=np.int64) + + return onnx.numpy_helper.from_array(np_array) + + # Special cases where attributes are considered as inputs, so we create Constant inputs + if node.operator == OP.RESHAPE: + if str(attr.name) == "shape": + var outputs = PythonObject([]) + outputs.append(str(node.operator) + "_" + str(attr.name) + "_" + str(node_number)) + var temp_node = onnx.helper.make_node( + op_type="Constant", + inputs=[], + outputs=outputs, + value=to_np_array(attr), + ) + + inputs.append(temp_node) + elif node.operator == OP.CLIP: + if str(attr.name) == "min" or str(attr.name) == "max": + var outputs = PythonObject([]) + outputs.append(str(node.operator) + "_" + str(attr.name) + "_" + str(node_number)) + var temp_node = onnx.helper.make_node( + op_type="Constant", + inputs=[], + outputs=outputs, + value=to_np_array(attr), + ) + + inputs.append(temp_node) + elif node.operator == OP.SQUEEZE or node.operator == OP.UNSQUEEZE: + if str(attr.name) == "dims": + var outputs = PythonObject([]) + outputs.append(str(node.operator) + "_" + str(attr.name) + "_" + str(node_number)) + var temp_node = onnx.helper.make_node( + op_type="Constant", + inputs=[], + outputs=outputs, + value=to_np_array(attr), + ) + + inputs.append(temp_node) + else: + var attr_value = make_onnx_attribute(node.operator, attr) + + attributes.append(attr_value) + + return (attributes, inputs) + + +fn export_onnx_model(model_path: Path, inout model_parameters: Parameters, g: Graph) raises: # Create onnx model with data and nodes var onnx = Python.import_module("onnx") var onnx_helper = Python.import_module("onnx.helper") @@ -238,7 +289,7 @@ fn export_onnx_model(model_path: Path, model_parameters: Parameters, g: Graph) r var op_type = make_onnx_operator_type(node.operator) var inputs = PythonObject([]) var outputs = PythonObject([]) - var name = str(node.operator) + "_node" + i + var name = str(node.operator) + "_node" + str(i) for j in range(len(node.inputs)): inputs.append(str(node.inputs[j].name)) @@ -261,6 +312,14 @@ fn export_onnx_model(model_path: Path, model_parameters: Parameters, g: Graph) r var onnx_output = onnx_helper.make_tensor_value_info(name, dtype, shape) graph.value_info.append(onnx_output) + # Process attributes + var attributes_and_inputs = create_attributes_and_constant_inputs(node, i) + var attributes = attributes_and_inputs[0] + var inputs_constant = attributes_and_inputs[1] + for j in range(len(inputs_constant)): + inputs.append(inputs_constant[j].output[0]) + graph.node.append(inputs_constant[j]) + # Create onnx node var onnx_node = onnx_helper.make_node( op_type, @@ -268,33 +327,8 @@ fn export_onnx_model(model_path: Path, model_parameters: Parameters, g: Graph) r outputs, name, ) - - # Process attributes - for j in range(len(node.attributes)): - var attr = node.attributes[j] - var attr_value = make_onnx_attribute(node.operator, attr) - - # Special case for reshape, shape in reshape is not an attribute, instead it is an input because they can be dynamic - if not node.operator == OP.RESHAPE: - onnx_node.attribute.append(attr_value) - - # Special case for reshape, shape in reshape is not an attribute, instead it is an input because they can be dynamic (it can be the result of another operator, don't know why) - if node.operator == OP.RESHAPE: - var shape = node.attributes[0].to_shape() - var list_shape = PythonObject([]) - for j in range(shape.rank()): - list_shape.append(shape[j]) - - graph.initializer.append( - onnx_helper.make_tensor( - name=name + "_shape", - data_type=onnx.TensorProto.INT64, - dims=(shape.rank(),), - vals=list_shape, - ) - ) - - onnx_node.input.append(name + "_shape") + for attribute in attributes: + onnx_node.attribute.append(attribute[]) graph.node.append(onnx_node) diff --git a/basalt/utils/perf_utils.mojo b/basalt/utils/perf_utils.mojo index 9cf076c7..bacd940e 100644 --- a/basalt/utils/perf_utils.mojo +++ b/basalt/utils/perf_utils.mojo @@ -1,5 +1,4 @@ from time import now -from math import min from memory import memset from basalt.autograd.node import Node @@ -7,10 +6,10 @@ from basalt.autograd.node import Node @always_inline("nodebug") fn fit_string[num: Int](s: String) -> String: - var data = DTypePointer[DType.int8]().alloc(num + 1) + var data = DTypePointer[DType.uint8]().alloc(num + 1) var copy_len = min(num, len(s)) - memcpy(data, s._as_ptr(), copy_len) + memcpy(data, s.unsafe_uint8_ptr(), copy_len) memset(data + copy_len, ord(" "), num - copy_len) data[num] = 0 @@ -20,11 +19,11 @@ fn fit_string[num: Int](s: String) -> String: @always_inline("nodebug") fn truncate_decimals[num: Int](s: String) -> String: try: - var parts = s.split(delimiter=".") + var parts = s.split(".") var truncated = parts[0] if len(parts) > 1: - var decimal_parts = parts[1].split(delimiter="e") + var decimal_parts = parts[1].split("e") truncated += "." + fit_string[num](decimal_parts[0]) if len(decimal_parts) > 1: @@ -125,7 +124,7 @@ struct PerfMetrics: print(header) var header_length = len(header) - var seperator = DTypePointer[DType.int8]().alloc(header_length + 1) + var seperator = DTypePointer[DType.uint8]().alloc(header_length + 1) memset(seperator, ord("-"), header_length) seperator[header_length] = 0 @@ -146,11 +145,11 @@ struct PerfMetrics: var print_value = ( fit_string[5](str(i)) + "| " - + fit_string[15](value.node.operator) + + fit_string[15](str(value.node.operator)) + "| " - + fit_string[20](truncate_decimals[4](time)) + + fit_string[20](truncate_decimals[4](str(time))) + "| " - + fit_string[20](truncate_decimals[3](percentage) + " %") + + fit_string[20](truncate_decimals[3](str(percentage)) + " %") + "| " ) diff --git a/basalt/utils/rand_utils.mojo b/basalt/utils/rand_utils.mojo index 69fd80f9..84b1925a 100644 --- a/basalt/utils/rand_utils.mojo +++ b/basalt/utils/rand_utils.mojo @@ -1,6 +1,7 @@ from basalt import Tensor from random import rand, randn from algorithm import vectorize +from utils.static_tuple import StaticTuple @always_inline @@ -71,4 +72,4 @@ struct MersenneTwister: return y fn next_ui8(inout self) -> UInt8: - return self.next().value & 0xFF + return self.next().value & int(0xFF) diff --git a/basalt/utils/tensor_creation_utils.mojo b/basalt/utils/tensor_creation_utils.mojo new file mode 100644 index 00000000..7662331d --- /dev/null +++ b/basalt/utils/tensor_creation_utils.mojo @@ -0,0 +1,77 @@ +from python import Python + +# maybe this functions should be from the Tensor struct (like tensor.to_numpy()) and tensor.__init__(np_array: PythonObject) to create a tensor from a numpy array and tensor.copy_np_data(np_array: PythonObject) to copy the numpy array to the tensor. + + +fn to_numpy(tensor: Tensor) -> PythonObject: + try: + var np = Python.import_module("numpy") + + np.set_printoptions(4) + + var rank = tensor.rank() + var dims = PythonObject([]) + for i in range(rank): + dims.append(tensor.dim(i)) + var pyarray: PythonObject = np.empty(dims, dtype=np.float32) + + var pointer = int(pyarray.__array_interface__["data"][0].to_float64()) + var pointer_d = DTypePointer[tensor.dtype](address=pointer) + memcpy(pointer_d, tensor.data(), tensor.num_elements()) + + _ = tensor + + return pyarray^ + except e: + print("Error in to numpy", e) + return PythonObject() + + +fn to_tensor(np_array: PythonObject) raises -> Tensor[dtype]: + var shape = List[Int]() + for i in range(np_array.ndim): + shape.append(int(np_array.shape[i].to_float64())) + if np_array.ndim == 0: + # When the numpy array is a scalar, you need or the reshape to a size 1 ndarray or do this, if not the memcpy gets a memory error (Maybe because it is a register value?). + var tensor = Tensor[dtype](TensorShape(1)) + tensor[0] = np_array.to_float64().cast[dtype]() + return tensor^ + + var tensor = Tensor[dtype](TensorShape(shape)) + + var np_array_2: PythonObject + try: + var np = Python.import_module("numpy") + # copy is also necessary for ops like slices to make them contiguous instead of references. + np_array_2 = np.float32(np_array.copy()) + except e: + np_array_2 = np_array.copy() + print("Error in to_tensor", e) + + var pointer = int(np_array_2.__array_interface__["data"][0].to_float64()) + var pointer_d = DTypePointer[tensor.dtype](address=pointer) + memcpy(tensor.data(), pointer_d, tensor.num_elements()) + + _ = np_array_2 + _ = np_array + + return tensor^ + + +fn copy_np_data(inout tensor: Tensor, np_array: PythonObject) raises: + var np_array_2: PythonObject + try: + var np = Python.import_module("numpy") + # copy is also necessary for ops like slices to make them contiguous instead of references. + np_array_2 = np.float32(np_array.copy()) + except e: + np_array_2 = np_array.copy() + print("Error in to_tensor", e) + + var pointer = int(np_array_2.__array_interface__["data"][0].to_float64()) + var pointer_d = DTypePointer[tensor.dtype](address=pointer) + memcpy(tensor.data(), pointer_d, tensor.num_elements()) + + _ = np_array_2 + _ = np_array + _ = tensor diff --git a/basalt/utils/tensorutils.mojo b/basalt/utils/tensorutils.mojo index fea82bcc..420ae5e4 100644 --- a/basalt/utils/tensorutils.mojo +++ b/basalt/utils/tensorutils.mojo @@ -1,13 +1,17 @@ from sys.info import num_physical_cores -from algorithm import vectorize, parallelize, swap +from algorithm import vectorize, parallelize from memory import memset_zero, memset, stack_allocation -from math import sqrt, pow, equal, max, min, add, div, divmod, abs +from math import sqrt from random import rand +from utils.numerics import min_finite, max_finite from basalt import Tensor, TensorShape from basalt.nn.tensor import MAX_RANK +from basalt.utils.math_util import add, sub, mul, div +# ---- Start ----- + @always_inline fn fill[dtype: DType](inout t: Tensor[dtype], val: Scalar[dtype]): @parameter @@ -48,14 +52,8 @@ fn broadcast_shapes(s1: TensorShape, s2: TensorShape) -> TensorShape: var ndim = max(s1.rank(), s2.rank()) var diff = abs(s1.rank() - s2.rank()) - var big: TensorShape - var small: TensorShape - if s1.rank() > s2.rank(): - big = s1 - small = s2 - else: - big = s2 - small = s1 + var big = s1 if s1.rank() > s2.rank() else s2 + var small = s2 if s1.rank() > s2.rank() else s1 var res = StaticIntTuple[MAX_RANK](-1) @@ -67,12 +65,7 @@ fn broadcast_shapes(s1: TensorShape, s2: TensorShape) -> TensorShape: elif a == 1 or b == 1: res[i] = a * b else: - # NOTE: consider assert and allow the function raises - var message: String = "[ERROR] Shapes " + str(s1) + " and " + str( - s2 - ) + " cannot be broadcasted together." - print(message) - # raise Error(message) + print("[ERROR] Shapes " + str(s1) + " and " + str(s2) + " cannot be broadcasted together.") for i in range(diff - 1, -1, -1): res[i] = big[i] @@ -91,9 +84,7 @@ fn broadcast_shapes(*s: TensorShape) -> TensorShape: @always_inline -fn broadcast_calculate_strides[ - size: Int, shape: TensorShape, broadcast_shape: TensorShape -]() -> StaticIntTuple[size]: +fn broadcast_calculate_strides[size: Int, shape: TensorShape, broadcast_shape: TensorShape]() -> StaticIntTuple[size]: alias shape_rank = shape.rank() alias diff = size - shape_rank @@ -107,181 +98,6 @@ fn broadcast_calculate_strides[ return strides - -# ----- Dot functions ----- -@always_inline -fn calculate_block[ - M: Int, N: Int, K: Int, BLOCK_M: Int, BLOCK_N: Int, nelts: Int -]( - res: DTypePointer[dtype], - t1: DTypePointer[dtype], - t2: DTypePointer[dtype], - bm: Int, - bn: Int, -): - # Compute tile - var acc = stack_allocation[BLOCK_M * BLOCK_N, dtype]() - memset_zero[dtype](acc, BLOCK_M * BLOCK_N) - - for k in range(K): - - @unroll - for m in range(BLOCK_M): - - @parameter - fn inner_n[nelts: Int](n: Int): - acc.store[width=nelts]( - m * BLOCK_N + n, - SIMD[dtype, nelts] - .splat(t1[(bm + m) * K + k]) - .fma( - t2.load[width=nelts](k * N + (bn + n)), - acc.load[width=nelts](m * BLOCK_N + n), - ), - ) - - vectorize[inner_n, nelts](BLOCK_N) - - # Store tile - for m in range(BLOCK_M): - - @parameter - fn vec_store[nelts: Int](n: Int): - res.store[width=nelts]( - (bm + m) * N + (bn + n), acc.load[width=nelts](m * BLOCK_N + n) - ) - - vectorize[vec_store, nelts](BLOCK_N) - - -@parameter -@always_inline -fn dot[ - t1_shape: TensorShape, t2_shape: TensorShape -](inout res: Tensor[dtype], t1: Tensor[dtype], t2: Tensor[dtype]): - dot[t1_shape, t2_shape](res.data(), t1.data(), t2.data()) - - -@parameter -@always_inline -fn dot[ - t1_shape: TensorShape, t2_shape: TensorShape -](res: DTypePointer[dtype], t1: DTypePointer[dtype], t2: DTypePointer[dtype]): - alias M = t1_shape[0] # t1[0] - alias K = t1_shape[1] # t1[1], t2[0] - alias N = t2_shape[1] # t2[1] - - # simdwidthof[dtype]() = 8 for float32 - alias nelts = simdwidthof[dtype]() - alias BLOCK_N = 8 * 2 - alias BLOCK_M = 6 - alias THREADS = 6 # num_logical_cores() - - alias BLOCK_N_REMAINDER = N % BLOCK_N - alias BLOCK_M_REMAINDER = M % BLOCK_M - - @parameter - fn bm_par(m_outer: Int): - var bm = m_outer * BLOCK_M - - for n_outer in range(0, N // BLOCK_N): - var bn = n_outer * BLOCK_N - - calculate_block[M, N, K, BLOCK_M, BLOCK_N, nelts](res, t1, t2, bm, bn) - - # Handle the remainder of N - @parameter - if BLOCK_N_REMAINDER > 0: - var bn = N - BLOCK_N_REMAINDER - - calculate_block[M, N, K, BLOCK_M, BLOCK_N_REMAINDER, nelts]( - res, t1, t2, bm, bn - ) - - parallelize[bm_par](M // BLOCK_M, M // BLOCK_M) - - # Handle the remainder of M - @parameter - if BLOCK_M_REMAINDER > 0: - var bm = M - BLOCK_M_REMAINDER - - for n_outer in range(0, N // BLOCK_N): - var bn = n_outer * BLOCK_N - - calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N, nelts]( - res, t1, t2, bm, bn - ) - - # Handle corner remainder - @parameter - if BLOCK_N_REMAINDER > 0: - var bn = N - BLOCK_N_REMAINDER - - calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N_REMAINDER, nelts]( - res, t1, t2, bm, bn - ) - - -fn dot_transpose_t2[ - A_shape: TensorShape, B_shape: TensorShape -](inout C: DTypePointer[dtype], A: DTypePointer[dtype], B: DTypePointer[dtype]): - dot[A_shape, TensorShape(B_shape[1], B_shape[0])](C, A, transpose_2D[B_shape](B)) - - -fn dot_transpose_t2[ - A_shape: TensorShape, B_shape: TensorShape -](inout C: Tensor[dtype], A: Tensor[dtype], B: Tensor[dtype]): - memset_zero[dtype](C.data(), C.num_elements()) - - dot[A_shape, TensorShape(B_shape[1], B_shape[0])](C, A, transpose_2D[B_shape](B)) - - # @parameter - # fn calc_row(i: Int): - # for j in range(B_shape[0]): - - # @parameter - # fn calc_row_A_B[nelts: Int](k: Int): - # var A_pos = i * A.dim(1) + k - # var B_pos = j * A.dim(1) + k - # var t_new_pos = i * C.dim(1) + j - - # C[t_new_pos] += ( - # A.load[nelts](A_pos) * B.load[nelts](B_pos) - # ).reduce_add() - - # vectorize[calc_row_A_B, nelts, size=A_shape[1]]() - - # parallelize[calc_row](A_shape[0], 1) - - -fn dot_transpose_t1[ - A_shape: TensorShape, B_shape: TensorShape -](inout C: Tensor[dtype], A: Tensor[dtype], B: Tensor[dtype]): - memset_zero[dtype](C.data(), C.num_elements()) - - dot[TensorShape(A_shape[1], A_shape[0]), B_shape](C, transpose_2D[A_shape](A), B) - - # @parameter - # fn calc_row(i: Int): - # for j in range(A_shape[0]): - - # @parameter - # fn calc_row_t_new_B[nelts: Int](k: Int): - # var A_pos = j * A.dim(1) + i - # var B_pos = j * B.dim(1) + k - # var t_new_pos = i * C.dim(1) + k - - # C.store[nelts]( - # t_new_pos, - # C.load[nelts](t_new_pos) - # + A[A_pos] * B.load[nelts](B_pos), - # ) - - # vectorize[calc_row_t_new_B, nelts, size=B_shape[1]]() - - # parallelize[calc_row](A_shape[1], 1) - - # ----- Element-wise unary operations ----- @always_inline fn elwise_transform[ @@ -660,13 +476,13 @@ fn _reduce_max[ @always_inline fn tmax(t: Tensor[dtype]) -> Scalar[dtype]: - var starting_value = math.limit.min_finite[dtype]() + var starting_value = min_finite[dtype]() return reduce[max, _reduce_max](t, starting_value) @always_inline fn tmax(inout res: Tensor[dtype], t: Tensor[dtype], axis: Int): - var starting_value = math.limit.min_finite[dtype]() + var starting_value = min_finite[dtype]() reduce[max, _reduce_max](res, t, axis, starting_value) diff --git a/examples/yolo_v8_utils.py b/examples/yolo_v8_utils.py new file mode 100644 index 00000000..e7b22e4d --- /dev/null +++ b/examples/yolo_v8_utils.py @@ -0,0 +1,174 @@ +import cv2 +import numpy as np + + +CLASSES = { + 0: "person", + 1: "bicycle", + 2: "car", + 3: "motorcycle", + 4: "airplane", + 5: "bus", + 6: "train", + 7: "truck", + 8: "boat", + 9: "traffic light", + 10: "fire hydrant", + 11: "stop sign", + 12: "parking meter", + 13: "bench", + 14: "bird", + 15: "cat", + 16: "dog", + 17: "horse", + 18: "sheep", + 19: "cow", + 20: "elephant", + 21: "bear", + 22: "zebra", + 23: "giraffe", + 24: "backpack", + 25: "umbrella", + 26: "handbag", + 27: "tie", + 28: "suitcase", + 29: "frisbee", + 30: "skis", + 31: "snowboard", + 32: "sports ball", + 33: "kite", + 34: "baseball bat", + 35: "baseball glove", + 36: "skateboard", + 37: "surfboard", + 38: "tennis racket", + 39: "bottle", + 40: "wine glass", + 41: "cup", + 42: "fork", + 43: "knife", + 44: "spoon", + 45: "bowl", + 46: "banana", + 47: "apple", + 48: "sandwich", + 49: "orange", + 50: "broccoli", + 51: "carrot", + 52: "hot dog", + 53: "pizza", + 54: "donut", + 55: "cake", + 56: "chair", + 57: "couch", + 58: "potted plant", + 59: "bed", + 60: "dining table", + 61: "toilet", + 62: "tv", + 63: "laptop", + 64: "mouse", + 65: "remote", + 66: "keyboard", + 67: "cell phone", + 68: "microwave", + 69: "oven", + 70: "toaster", + 71: "sink", + 72: "refrigerator", + 73: "book", + 74: "clock", + 75: "vase", + 76: "scissors", + 77: "teddy bear", + 78: "hair drier", + 79: "toothbrush", +} + +colors = np.random.uniform(0, 255, size=(len(CLASSES), 3)) + + +def bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h): + label = f'{CLASSES[class_id]} ({confidence:.2f})' + color = colors[class_id] + cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2) + cv2.putText(img, label, (x - 10, y - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) + + +def get_model_input(image_path): + # load the image + original_image: np.ndarray = cv2.imread(image_path) + height, width, _ = original_image.shape + length = max((height, width)) + image = np.pad(original_image, ((0, length - height), (0, length - width), (0, 0)), mode='constant', constant_values=0) + blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True) + + return blob + + +def draw_bounding_box_yolo(original_image, outputs): + height, width, _ = original_image.shape + length = max((height, width)) + image = np.zeros((length, length, 3), np.uint8) + image[:height, :width] = original_image + scale = length / 640 + + outputs = np.array([cv2.transpose(outputs[0])]) + rows = outputs.shape[1] + + boxes = [] + scores = [] + class_ids = [] + + for i in range(rows): + classes_scores = outputs[0][i][4:] + (_, maxScore, _, (_, maxClassIndex)) = cv2.minMaxLoc(classes_scores) + if maxScore >= 0.25: + box = [ + outputs[0][i][0] - (0.5 * outputs[0][i][2]), + outputs[0][i][1] - (0.5 * outputs[0][i][3]), + outputs[0][i][2], + outputs[0][i][3] + ] + boxes.append(box) + scores.append(maxScore) + class_ids.append(maxClassIndex) + + result_boxes = cv2.dnn.NMSBoxes(boxes, scores, 0.25, 0.45, 0.5) + + detections = [] + for i in range(len(result_boxes)): + index = result_boxes[i] + box = boxes[index] + detection = { + 'class_id': class_ids[index], + 'class_name': CLASSES[class_ids[index]], + 'confidence': scores[index], + 'box': box, + 'scale': scale} + detections.append(detection) + + bounding_box( + original_image, + class_ids[index], + scores[index], + round(box[0] * scale), + round(box[1] * scale), + round((box[0] + box[2]) * scale), + round((box[1] + box[3]) * scale) + ) + + return detections + + +def draw_bbox_from_image(image_path, outputs): + image: np.ndarray = cv2.imread(image_path) + + detections = draw_bounding_box_yolo(image, outputs) + + cv2.imshow('image', image) + cv2.waitKey(0) + cv2.destroyAllWindows() + + print(detections) diff --git a/examples/yolov8.mojo b/examples/yolov8.mojo new file mode 100644 index 00000000..c23fd998 --- /dev/null +++ b/examples/yolov8.mojo @@ -0,0 +1,304 @@ +import basalt.nn as nn +from basalt import Tensor, TensorShape +from basalt import Graph, Symbol, OP, dtype +from basalt.autograd.attributes import AttributeVector, Attribute +from basalt.utils.tensor_creation_utils import to_tensor, to_numpy + +from python import Python +from math import ceil +from utils.static_tuple import StaticTuple + + +fn Conv( + inout g: Graph, + x: Symbol, + out_channels: Int, + kernel_size: Int, + padding: Int, + stride: Int, +) -> Symbol: + # NOTE: This is functionally equivalent to the Conv2D -> BatchNorm2D (removed in graph) -> SiLU (According to ONNX) + var conv = nn.Conv2d(g, x, out_channels, kernel_size, padding, stride) + var sigmoid = g.op(OP.SIGMOID, conv) + return g.op(OP.MUL, conv, sigmoid) + + +fn Conv( + inout g: Graph, + x: Symbol, + weight: Symbol, + bias: Symbol, + kernel_size: StaticIntTuple[2], + padding: StaticIntTuple[2], + stride: StaticIntTuple[2], +) -> Symbol: + # NOTE: This is functionally equivalent to the Conv2D -> BatchNorm2D (removed in graph) -> SiLU (According to ONNX) + var conv = g.op(OP.CONV2D, x, weight, bias, attributes=AttributeVector( + Attribute("padding", padding), + Attribute("stride", stride), + Attribute("dilation", StaticIntTuple[2](1, 1)), + )) + var sigmoid = g.op(OP.SIGMOID, conv) + return g.op(OP.MUL, conv, sigmoid) + + +fn C2f( + inout g: Graph, + x: Symbol, + out_channels: Int, + n: Int, + shortcut: Bool +) -> Symbol: + var conv = Conv(g, x, out_channels, 1, 0, 1) + + var split_size = out_channels // 2 + var split_sections = List[Int](split_size, split_size) + var split = g.split(conv, split_sections, dim=1) + + # declare the weights for the last conv here because that is the order in onnx file + var n_temp = 1 + if n > 1: + n_temp = 2 + var weight = g.param(TensorShape(out_channels, split_size * (n + 2), 1, 1)) + var bias = g.param(TensorShape(out_channels)) + + @parameter + fn bottleneck( + x: Symbol, out_channels: Int, shortcut: Bool = False + ) -> Symbol: + var conv1 = Conv(g, x, out_channels, 3, 1, 1) + var conv2 = Conv(g, conv1, out_channels, 3, 1, 1) + + if shortcut: + return g.op(OP.ADD, x, conv2) + else: + return conv2 + + var y1 = bottleneck(split[1], split_size, shortcut) + var y2 = y1 + + var concat_list = List[Symbol]() # add ability to concat to receive a list, becauase the the concatenation has to be done for each bottleneck layer that was run + + # NOTE: This assumes n >= 1 (Could add a constrained for it later) + for i in range(1, n): + y2 = bottleneck(y2, split_size, shortcut) + # concat_list.append(y2) + + # add ability to concat to receive a list, becauase the the concatenation has to be done for each bottleneck layer that was run + var y: Symbol + if n > 1: + y = g.concat(split[0], split[1], y1, y2, dim=1) + else: + y = g.concat(split[0], split[1], y1, dim=1) + + return Conv(g, y, weight, bias, 1, 0, 1) + + +fn SPPF(inout g: Graph, x: Symbol, out_channels: Int) -> Symbol: + var conv = Conv(g, x, out_channels // 2, 1, 0, 1) + + var maxpool2d_1 = nn.MaxPool2d(g, conv, kernel_size=5, stride=StaticIntTuple[2](1), padding=2) + var maxpool2d_2 = nn.MaxPool2d(g, maxpool2d_1, kernel_size=5, stride=StaticIntTuple[2](1), padding=2) + var maxpool2d_3 = nn.MaxPool2d(g, maxpool2d_2, kernel_size=5, stride=StaticIntTuple[2](1), padding=2) + + var y = g.concat(conv, maxpool2d_1, maxpool2d_2, maxpool2d_3, dim=1) + + return Conv(g, y, out_channels, 1, 0, 1) + + +fn Detect(inout g: Graph, x: Symbol, out_channels: Int, nc: Int, detect_conv: Int) -> Symbol: + # self.nc = nc # number of classes + # self.nl = len(ch) # number of detection layers + # self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x) + # self.no = nc + self.reg_max * 4 # number of outputs per anchor + + var reg_max = 16 + + var c2 = max(max(16, out_channels // 4), reg_max * 4) + var c3 = max(0, nc) # channels + + if detect_conv == 1: + var conv1 = Conv(g, x, c2, 3, 1, 1) + var conv1_2 = Conv(g, conv1, c2, 3, 1, 1) + var conv1_3 = nn.Conv2d(g, conv1_2, 4 * reg_max, 1, 0, 1) + + return conv1_3 + else: + var conv2 = Conv(g, x, c3, 3, 1, 1) + var conv2_2 = Conv(g, conv2, c3, 3, 1, 1) + var conv2_3 = nn.Conv2d(g, conv2_2, nc, 1, 0, 1) + + return conv2_3 + + +fn YoloV8(batch_size: Int, yolo_model_type: StaticTuple[Float64, 3]) -> Graph: + var g = Graph() + var x = g.input(TensorShape(batch_size, 3, 640, 640)) + + # Adapted from https://private-user-images.githubusercontent.com/27466624/239739723-57391d0f-1848-4388-9f30-88c2fb79233f.jpg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTUxMTk0MDYsIm5iZiI6MTcxNTExOTEwNiwicGF0aCI6Ii8yNzQ2NjYyNC8yMzk3Mzk3MjMtNTczOTFkMGYtMTg0OC00Mzg4LTlmMzAtODhjMmZiNzkyMzNmLmpwZz9Y>LUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA1MDclMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNTA3VDIxNTgyNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTNlZTdkY2ZiMDA0Y2VlOGZkYjllN2FkYTQ1MTY5OWY1YzYwNjIxZDM4OTZiYWRiMGU5YWQxNzkyMTcwNGNmNTQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.0ocPCiokkivvk95bQCds6Nt0EblUrHZElycV311ImF4. Some values (output_channels, stride, etc..) are different in the onnx file and the graph image. + + # Backbone + var out_channels_1 = int(64 * yolo_model_type[1]) + var conv_1 = Conv(g, x, out_channels_1, 3, 1, 2) + var out_channels_2 = int(128 * yolo_model_type[1]) + var conv_2 = Conv(g, conv_1, out_channels_2, 3, 1, 2) + var C2F_n_1 = int((3 * yolo_model_type[0]) + 1) # ceil + var C2f_1 = C2f(g, conv_2, out_channels_2, n=C2F_n_1, shortcut=True) + var out_channels_3 = int(256 * yolo_model_type[1]) + var conv_3 = Conv(g, C2f_1, out_channels_3, 3, 1, 2) + var C2F_n_2 = int((6 * yolo_model_type[0]) + 1) # ceil + var C2f_2 = C2f(g, conv_3, out_channels_3, n=C2F_n_2, shortcut=True) + + var out_channels_4 = int(512 * yolo_model_type[1]) + var conv_4 = Conv(g, C2f_2, out_channels_4, 3, 1, 2) + var C2f_3 = C2f(g, conv_4, out_channels_4, n=C2F_n_2, shortcut=True) + + var out_channels_5 = int(512 * yolo_model_type[1] * yolo_model_type[2]) + var conv_5 = Conv(g, C2f_3, out_channels_5, 3, 1, 2) + var C2f_4 = C2f(g, conv_5, out_channels_5, n=C2F_n_1, shortcut=True) + var SPPF_1 = SPPF(g, C2f_4, out_channels_5) + + # Head + var upsample_1 = g.op(OP.UPSAMPLE, SPPF_1, attributes=AttributeVector(Attribute("mode", "nearest"), Attribute("scales", TensorShape(2, 2)))) + + # The order of concats was wrong + var concat_1 = g.concat(upsample_1, C2f_3, dim=1) + + var out_channels_6 = int(512 * yolo_model_type[1]) + var C2f_5 = C2f(g, concat_1, out_channels_6, n=C2F_n_1, shortcut=False) + + var upsample_2 = g.op(OP.UPSAMPLE, C2f_5, attributes=AttributeVector(Attribute("mode", "nearest"), Attribute("scales", TensorShape(2, 2)))) + + var concat_2 = g.concat(upsample_2, C2f_2, dim=1) + + var out_channels_7 = int(256 * yolo_model_type[1]) + var C2f_6 = C2f(g, concat_2, out_channels_7, n=C2F_n_1, shortcut=False) + + var conv_6 = Conv(g, C2f_6, out_channels_7, 3, 1, 2) + var concat_3 = g.concat(conv_6, C2f_5, dim=1) + var C2f_7 = C2f(g, concat_3, out_channels_6, n=C2F_n_1, shortcut=False) + + var conv_7 = Conv(g, C2f_7, out_channels_6, 3, 1, 2) + var concat_4 = g.concat(conv_7, SPPF_1, dim=1) + var out_channels_8 = int(512 * yolo_model_type[1] * yolo_model_type[2]) + var C2f_8 = C2f(g, concat_4, out_channels_8, n=C2F_n_1, shortcut=False) + + # Detect + # declare them this way because the order of initializers in the onnx file is like this + var detect_1 = Detect(g, C2f_6, out_channels_7, 80, 1) + var detect_2 = Detect(g, C2f_7, out_channels_6, 80, 1) + var detect_3 = Detect(g, C2f_8, out_channels_8, 80, 1) + + var detect_1_1 = Detect(g, C2f_6, out_channels_7, 80, 2) + var detect_2_1 = Detect(g, C2f_7, out_channels_6, 80, 2) + var detect_3_1 = Detect(g, C2f_8, out_channels_8, 80, 2) + + var concat_detect_1 = g.concat(detect_1, detect_1_1, dim=1) + var concat_detect_2 = g.concat(detect_2, detect_2_1, dim=1) + var concat_detect_3 = g.concat(detect_3, detect_3_1, dim=1) + + # -------- output + var reshape_1 = g.op(OP.RESHAPE, concat_detect_1, attributes=AttributeVector(Attribute("shape", TensorShape(1, 144, concat_detect_1.shape[2] * concat_detect_1.shape[3])))) + + var reshape_2 = g.op(OP.RESHAPE, concat_detect_2, attributes=AttributeVector(Attribute("shape", TensorShape(1, 144, concat_detect_2.shape[2] * concat_detect_2.shape[3])))) + + var reshape_3 = g.op(OP.RESHAPE, concat_detect_3, attributes=AttributeVector(Attribute("shape", TensorShape(1, 144, concat_detect_3.shape[2] * concat_detect_3.shape[3])))) + + # -- + + var concat_5 = g.concat(reshape_1, reshape_2, reshape_3, dim=2) + + var split_sections = List[Int](64, 80) + var split_1 = g.split(concat_5, split_sections, dim=1) + + var for_second_concat = g.op(OP.SIGMOID, split_1[1]) + + var reshape_4 = g.op(OP.RESHAPE, split_1[0], attributes=AttributeVector(Attribute("shape", TensorShape(1, 4, 16, 8400)))) + + var transpose_1 = g.op(OP.TRANSPOSE, reshape_4, attributes=AttributeVector(Attribute("axes", List[Int](0, 2, 1, 3)))) + + var softmax = nn.Softmax(g, transpose_1, axis=1) + + var conv_norm_1 = nn.Conv2d(g, softmax, 1, 1, 0, 1, 1) + + var reshape_5 = g.op(OP.RESHAPE, conv_norm_1, attributes=AttributeVector(Attribute("shape", TensorShape(1, 4, 8400)))) + + var slice_1 = g.op(OP.SLICE, reshape_5, attributes=AttributeVector( + Attribute("axes", List[Int](1)), + Attribute("starts", List[Int](0)), + Attribute("ends", List[Int](2)))) + var slice_2 = g.op(OP.SLICE, reshape_5, attributes=AttributeVector( + Attribute("axes", List[Int](1)), + Attribute("starts", List[Int](2)), + Attribute("ends", List[Int](4)))) + + var sub_constant_value = g.input(TensorShape(1, 2, 8400)) + var sub_with_constant_1 = g.op(OP.SUB, sub_constant_value, slice_1) + var add_constant_value = g.input(TensorShape(1, 2, 8400)) + var add_with_constant_2 = g.op(OP.ADD, add_constant_value, slice_2) + + var add_1 = g.op(OP.ADD, sub_with_constant_1, add_with_constant_2) + var sub_1 = g.op(OP.SUB, add_with_constant_2, sub_with_constant_1) + + var div_1 = g.op(OP.DIV, add_1, 2) + + var concat_6 = g.concat(div_1, sub_1, dim=1) + + var mul_constant_value = g.input(TensorShape(1, 8400)) + var mul_with_constant_1 = g.op(OP.MUL, concat_6, mul_constant_value) + + var concat_7 = g.concat(mul_with_constant_1, for_second_concat, dim=1) + + g.out(concat_7) + + return g ^ + + +alias yolov8_n = StaticTuple[Float64, 3]( + 0.33, 0.25, 2 +) # d (depth_multiplier), w (width_multiplier), r (ratio) +# var yolov8_s +# var yolov8_m + + +fn get_constant_values_from_onnx_model(model_path: String) raises -> List[Tensor[dtype]]: + var onnx = Python.import_module("onnx") + + var model = onnx.load(model_path) + + var result = List[Tensor[dtype]]() + + for node in model.graph.node: + if node.op_type == "Constant": + for attr in node.attribute: + if attr.name == 'value': + var tensor = onnx.numpy_helper.to_array(attr.t) + if node.name == "/model.22/Constant_9": + result.append(to_tensor(tensor)) + if node.name == "/model.22/Constant_10": + result.append(to_tensor(tensor)) + if node.name == "/model.22/Constant_12": + result.append(to_tensor(tensor)) + + return result + + +fn main() raises: + alias graph = YoloV8(1, yolov8_n) + var model = nn.Model[graph]() + + # try: graph.render("node") + # except: print("Could not render graph") + + + model.load_model_data("./examples/data/yolov8n.onnx") + + var constant_values = get_constant_values_from_onnx_model("./examples/data/yolov8n.onnx") + + Python.add_to_path("./examples") + var yolo_utils = Python.import_module("yolo_v8_utils") + var image_tensor = to_tensor(yolo_utils.get_model_input('./examples/data/bus.jpg')) + + var res = model.inference(image_tensor, constant_values[0], constant_values[1], constant_values[2]) + + yolo_utils.draw_bbox_from_image("./examples/data/bus.jpg", to_numpy(res[0])) \ No newline at end of file diff --git a/examples/yolov8_cam.mojo b/examples/yolov8_cam.mojo new file mode 100644 index 00000000..2c0cb3d8 --- /dev/null +++ b/examples/yolov8_cam.mojo @@ -0,0 +1,85 @@ +import sys +from time.time import now +from python.python import Python +from utils.static_tuple import StaticTuple + +from yolov8 import YoloV8, get_constant_values_from_onnx_model + +import basalt.nn as nn +from basalt import Tensor, TensorShape, dtype +from basalt.utils.tensor_creation_utils import to_tensor, to_numpy + + +fn cam( + inout model: nn.Model, + constants: List[Tensor[dtype]] +) raises: + + Python.add_to_path("./examples") + var yolo_utils = Python.import_module("yolo_v8_utils") + + var cv2 = Python.import_module("cv2") + var np = Python.import_module("numpy") + var cap = cv2.VideoCapture(0) + + if not cap.isOpened(): + print("Error: Could not open webcam") + sys.exit(1) + + var height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT).to_float64() + var width = cap.get(cv2.CAP_PROP_FRAME_WIDTH).to_float64() + var length = max(height, width) + var pads = np.array([0, length - height, 0, length - width, 0, 0], dtype=np.int32).reshape(3, 2) + + var last_time = now() + + while True: + var r = cap.read() + + if not r[0]: + print("Error: Could not read frame") + break + + var image = np.pad(r[1], pads, mode='constant', constant_values=0) + var blob = cv2.dnn.blobFromImage(image, scalefactor=1/255, size=(640, 640), swapRB=True) + + var res = model.inference(to_tensor(blob), constants[0], constants[1], constants[2]) + + yolo_utils.draw_bounding_box_yolo(r[1], to_numpy(res[0])) + cv2.imshow( + 'Basalt', + cv2.putText( + r[1], + "FPS: " + String(1e9 / (now() - last_time)), + (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.30, (0, 0, 0), 1, cv2.LINE_AA + ) + ) + + last_time = now() + if int(cv2.waitKey(1) & 0xFF) == 27 or cv2.getWindowProperty('Basalt', cv2.WND_PROP_VISIBLE) < 1: + cv2.destroyAllWindows() + sys.exit() + + +fn main(): + + alias yolov8_n = StaticTuple[Float64, 3]( + 0.33, 0.25, 2 + ) # d (depth_multiplier), w (width_multiplier), r (ratio) + + alias graph = YoloV8(1, yolov8_n) + var model = nn.Model[graph]() + + model.load_model_data("./examples/data/yolov8n.onnx") + + try: + var constant_values = get_constant_values_from_onnx_model("./examples/data/yolov8n.onnx") + + cam( + model, + constant_values + ) + + except e: + print("Error in cam() function") + print(e) diff --git a/examples/yolov8n_onnx.py b/examples/yolov8n_onnx.py new file mode 100644 index 00000000..5b4eb2de --- /dev/null +++ b/examples/yolov8n_onnx.py @@ -0,0 +1,22 @@ +import cv2.dnn +import numpy as np +import onnxruntime as rt + +from yolo_v8_utils import get_model_input, draw_bbox_from_image + + +def main(onnx_model, input_image): + blob = get_model_input(input_image) + + model: cv2.dnn.Net = cv2.dnn.readNetFromONNX(onnx_model) + model = rt.InferenceSession(onnx_model) + + outputs = model.run(None, {"images": blob})[0] + + draw_bbox_from_image(input_image, outputs) + + +main( + onnx_model='examples/data/yolov8n.onnx', + input_image="examples/data/bus.jpg" +) \ No newline at end of file diff --git a/profile.sh b/profile.sh deleted file mode 100755 index 870754f9..00000000 --- a/profile.sh +++ /dev/null @@ -1,118 +0,0 @@ -#!/bin/bash - -function profile() { - if [ ! -d ~/FlameGraph ]; then - InstallFlameGraph - fi - - if [ -f /proc/sys/fs/binfmt_misc/WSLInterop ]; then - profileLinux "$1" - else - case "$OSTYPE" in - darwin*) - profileMac "$1" - ;; - linux-gnu*|msys) - profileLinux "$1" - ;; - esac - fi -} - -function profileLinux() { - local mojo_file=$1 - LinuxInstallDependencies - LinuxPermissions - runProfile "$mojo_file" -} - -function profileMac() { - local mojo_file=$1 - MacInstallDependencies - MacPermissions - runProfile "$mojo_file" -} - -function runProfile() { - local mojo_file=$1 - local mojo_name="${mojo_file%.mojo}" - local temp_dir="./temp" - local perf_output="$temp_dir/out.perf" - local flamegraph_output="flamegraph.svg" - - echo "Profiling $mojo_file..." - - mkdir -p "$temp_dir" - - echo "Building $mojo_file..." - mojo build -I . "$mojo_file" - - echo "Stripping debug symbols..." - mv "$mojo_name" "$temp_dir/run.exe" - llvm-strip --strip-debug "$temp_dir/run.exe" - - echo "Running perf record..." - sudo perf record -F 99 -a -g -o "$perf_output" -- "$temp_dir/run.exe" - - echo "Generating flamegraph..." - sudo perf script -i "$perf_output" | ~/FlameGraph/stackcollapse-perf.pl | ~/FlameGraph/flamegraph.pl > "$flamegraph_output" - - echo "Opening flamegraph: $flamegraph_output" - - if command -v open &> /dev/null; then - open "$flamegraph_output" - elif command -v explorer.exe &> /dev/null; then - explorer.exe "$flamegraph_output" - elif command -v google-chrome &> /dev/null; then - google-chrome "$flamegraph_output" - fi - - echo "Cleaning up temporary files..." - rm -rf "$temp_dir" - - echo "Profiling completed." -} - -function LinuxInstallDependencies() { - if ! command -v perf &> /dev/null; then - echo "Installing perf for Linux/WSL" - sudo apt-get update - sudo apt-get install -y linux-tools-common linux-tools-generic - fi - - if ! command -v llvm-strip &> /dev/null; then - echo "Installing LLVM for Linux/WSL" - sudo apt-get install -y llvm - fi -} - -function MacInstallDependencies() { - if ! command -v perf &> /dev/null; then - echo "Installing perf for Mac" - brew install perf - fi - - if ! command -v llvm-strip &> /dev/null; then - echo "Installing LLVM for Mac" - brew install llvm - fi -} - -function InstallFlameGraph() { - echo "Installing FlameGraph" - git clone https://github.com/brendangregg/FlameGraph.git - mv FlameGraph ~/FlameGraph -} - -function LinuxPermissions() { - echo "Setting Linux/WSL permissions" - echo 0 | sudo tee /proc/sys/kernel/kptr_restrict > /dev/null - echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid > /dev/null - sudo sysctl -p > /dev/null -} - -function MacPermissions() { - echo "Setting Mac permissions" -} - -profile "$1" diff --git a/tests/mojo/test_activations.mojo b/tests/mojo/test_activations.mojo index 200215d9..f2db8e91 100644 --- a/tests/mojo/test_activations.mojo +++ b/tests/mojo/test_activations.mojo @@ -8,6 +8,7 @@ from basalt.nn import ( Softmax, LogSoftmax, ReLU, + LeakyReLU, Sigmoid, Tanh, ) @@ -19,6 +20,9 @@ from tests import assert_tensors_equal alias Activation = fn (inout g: Graph, input: Symbol) -> Symbol alias AxisActivation = fn (inout g: Graph, input: Symbol, axis: Int) -> Symbol +alias LeakyReLUActivation = fn ( + inout g: Graph, input: Symbol, negative_slope: Scalar[dtype] +) -> Symbol fn create_graph[ @@ -30,7 +34,19 @@ fn create_graph[ var x = g.input(shape) var activation = func(g, x, axis) g.out(activation) - return g ^ + return g^ + + +fn create_graph[ + shape: TensorShape, + func: LeakyReLUActivation, + negative_slope: Scalar[dtype], +]() -> Graph: + var g = Graph() + var x = g.input(shape) + var activation = func(g, x, negative_slope) + g.out(activation) + return g^ fn create_graph[shape: TensorShape, func: Activation]() -> Graph: @@ -38,7 +54,7 @@ fn create_graph[shape: TensorShape, func: Activation]() -> Graph: var x = g.input(shape) var activation = func(g, x) g.out(activation) - return g ^ + return g^ fn test_graph[ @@ -56,6 +72,22 @@ fn test_graph[ assert_equal(len(graph.nodes), nodes) +fn test_graph[ + shape: TensorShape, + func: LeakyReLUActivation, + nodes: Int, + negative_slope: Scalar[dtype], +](input: Tensor[dtype], expected: Tensor[dtype]) raises: + alias graph = create_graph[shape, func, negative_slope]() + + var model = Model[graph](inference_only=True) + var res = model.inference(input)[0] + + assert_tensors_equal["almost"](res, expected) + assert_equal(len(graph.nodes), nodes) + + +# TODO: All these overloads feel redundant. Find a way to condense them fn test_graph[ shape: TensorShape, func: Activation, @@ -125,6 +157,25 @@ fn test_RELU() raises: test_graph[shape, ReLU, nodes](input, expected) +fn test_LEAKYRELU() raises: + alias negative_slope = 0.1 + + alias shape = TensorShape(2, 3) + alias nodes = 1 + + var input = Tensor[dtype](shape) + + for i in range(6): + input[i] = i - 3 + + var expected = Tensor[dtype](shape) + + for i in range(6): + expected[i] = i - 3 if i - 3 > 0 else negative_slope * (i - 3) + + test_graph[shape, LeakyReLU, nodes, negative_slope](input, expected) + + fn test_SIGMOID() raises: alias shape = TensorShape(2, 3) alias nodes = 1 @@ -156,6 +207,7 @@ fn main(): test_SOFTMAX() test_LOGSOFTMAX() test_RELU() + test_LEAKYRELU() test_SIGMOID() test_TANH() except e: diff --git a/tests/mojo/test_backward.mojo b/tests/mojo/test_backward.mojo index 167a232b..d8acc45e 100644 --- a/tests/mojo/test_backward.mojo +++ b/tests/mojo/test_backward.mojo @@ -157,10 +157,15 @@ fn test_POW() raises: fill(temp, (2**2) * log[dtype, 1](2)) expected_grad2[0] = tsum(temp) - test_binary_op_backward[OP.POW, t1_shape, t2_shape, ug_shape]( - t1, t2, ug, expected_grad1, expected_grad2 - ) + test_binary_op_backward[OP.POW, t1_shape, t2_shape, ug_shape](t1, t2, ug, expected_grad1, expected_grad2) + + fill(t1, 0.0) + fill(t2, 0) + fill(ug, 1.0) + fill(expected_grad1, 0.0) + fill(expected_grad2, 0.0) + test_binary_op_backward[OP.POW, t1_shape, t2_shape, ug_shape](t1, t2, ug, expected_grad1, expected_grad2) fn test_SUM() raises: alias t1_shape = TensorShape(2, 3) diff --git a/tests/mojo/test_mlops.mojo b/tests/mojo/test_mlops.mojo index 2ba723e6..07bdde58 100644 --- a/tests/mojo/test_mlops.mojo +++ b/tests/mojo/test_mlops.mojo @@ -1,11 +1,24 @@ from basalt import dtype, nelts from basalt.autograd import OP from basalt.autograd.attributes import AttributeVector, Attribute -from basalt.autograd.ops.mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE +from basalt.autograd.ops.mlops import ( + SIGMOID, + RELU, + LEAKYRELU, + TANH, + CLIP, + SQUEEZE, + UNSQUEEZE, +) from basalt.nn import Tensor, TensorShape from basalt.utils.tensorutils import fill -from tests import assert_tensors_equal, test_unary_op, test_unary_op_backward, to_numpy +from tests import ( + assert_tensors_equal, + test_unary_op, + test_unary_op_backward, + to_numpy, +) fn test_SIGMOID() raises: @@ -30,7 +43,9 @@ fn test_backward_SIGMOID() raises: expected_grad, 5.0 * 0.25 ) # 0.25 = d(sigmoid(0))/dx = sigmoid(0) * (1 - sigmoid(0)) - test_unary_op_backward[OP.SIGMOID, t1_shape, ug_shape](t1, ug, expected_grad) + test_unary_op_backward[OP.SIGMOID, t1_shape, ug_shape]( + t1, ug, expected_grad + ) fn test_RELU() raises: @@ -71,6 +86,53 @@ fn test_backward_RELU() raises: test_unary_op_backward[OP.RELU, t1_shape, ug_shape](t1, ug, expected_grad) +fn test_LEAKYRELU() raises: + alias t1_shape = TensorShape(2, 3) + var t1: Tensor[dtype] = Tensor[dtype](t1_shape) + # TODO: When tensors can do slices, this could be changed to two fill functions. + for i in range(3): + t1[i] = 3 + for i in range(3, 6): + t1[i] = -3 + + var expected = Tensor[dtype](2, 3) + for i in range(3): + expected[i] = 3 + for i in range(3, 6): + expected[i] = -0.3 + + test_unary_op[ + OP.LEAKYRELU, + t1_shape, + AttributeVector(Attribute("negative_slope", 0.1)), + ](t1, expected) + + +fn test_backward_LEAKYRELU() raises: + alias t1_shape = TensorShape(2, 3) + alias ug_shape = TensorShape(2, 3) + var t1: Tensor[dtype] = Tensor[dtype](t1_shape) + var ug: Tensor[dtype] = Tensor[dtype](ug_shape) + for i in range(3): + t1[i] = 3 + for i in range(3, 6): + t1[i] = -3 + fill(ug, 5.0) + + var expected_grad = Tensor[dtype](2, 3) + for i in range(3): + expected_grad[i] = 1 * 5.0 + for i in range(3, 6): + expected_grad[i] = 0.1 * 5.0 + + test_unary_op_backward[ + OP.LEAKYRELU, + t1_shape, + ug_shape, + AttributeVector(Attribute("negative_slope", 0.1)), + ](t1, ug, expected_grad) + + fn test_TANH() raises: alias t1_shape = TensorShape(2, 3) var t1: Tensor[dtype] = Tensor[dtype](t1_shape) @@ -110,7 +172,9 @@ fn test_CLIP() raises: for i in range(6): var val = Scalar[dtype](i - 3) expected_min[i] = val if (val > -1.1) else -1.1 - test_unary_op[OP.CLIP, t1_shape, AttributeVector(min_attr)](t1, expected_min) + test_unary_op[OP.CLIP, t1_shape, AttributeVector(min_attr)]( + t1, expected_min + ) # Clip with max alias max_attr = Attribute("max", 1.1) @@ -118,7 +182,9 @@ fn test_CLIP() raises: for i in range(6): var val = Scalar[dtype](i - 3) expected_max[i] = val if (val < 1.1) else 1.1 - test_unary_op[OP.CLIP, t1_shape, AttributeVector(max_attr)](t1, expected_max) + test_unary_op[OP.CLIP, t1_shape, AttributeVector(max_attr)]( + t1, expected_max + ) # Clip with min and max var expected = Tensor[dtype](2, 3) @@ -130,7 +196,9 @@ fn test_CLIP() raises: expected[i] = 1.1 else: expected[i] = val - test_unary_op[OP.CLIP, t1_shape, AttributeVector(min_attr, max_attr)](t1, expected) + test_unary_op[OP.CLIP, t1_shape, AttributeVector(min_attr, max_attr)]( + t1, expected + ) fn test_backward_CLIP() raises: @@ -152,7 +220,9 @@ fn test_backward_CLIP() raises: for i in range(6): var val = Scalar[dtype](i - 3) expected_min[i] = 5.0 if (val > -1.1) else 0.0 - test_unary_op_backward[OP.CLIP, t1_shape, ug_shape, min_attr](t1, ug, expected_min) + test_unary_op_backward[OP.CLIP, t1_shape, ug_shape, min_attr]( + t1, ug, expected_min + ) # Clip with max alias max_attr = AttributeVector(Attribute("max", 1.1)) @@ -160,7 +230,9 @@ fn test_backward_CLIP() raises: for i in range(6): var val = Scalar[dtype](i - 3) expected_max[i] = 5.0 if (val < 1.1) else 0.0 - test_unary_op_backward[OP.CLIP, t1_shape, ug_shape, max_attr](t1, ug, expected_max) + test_unary_op_backward[OP.CLIP, t1_shape, ug_shape, max_attr]( + t1, ug, expected_max + ) # Clip with min and max alias attrs = AttributeVector(Attribute("min", -1.1), Attribute("max", 1.1)) @@ -201,7 +273,9 @@ fn test_SQUEEZE() raises: expected = Tensor[dtype](1, 2, 3) fill(expected, 5.0) test_unary_op[ - OP.SQUEEZE, t1_shape, AttributeVector(Attribute("dims", TensorShape(2, 4))) + OP.SQUEEZE, + t1_shape, + AttributeVector(Attribute("dims", TensorShape(2, 4))), ](t1, expected) @@ -216,7 +290,9 @@ fn test_backward_SQUEEZE() raises: var expected_grad = Tensor[dtype](2, 1, 3, 1) fill(expected_grad, 5.0) - test_unary_op_backward[OP.SQUEEZE, t1_shape, ug_shape](t1, ug, expected_grad) + test_unary_op_backward[OP.SQUEEZE, t1_shape, ug_shape]( + t1, ug, expected_grad + ) fn test_UNSQUEEZE() raises: @@ -228,26 +304,34 @@ fn test_UNSQUEEZE() raises: var expected = Tensor[dtype](2, 1, 3, 1) fill(expected, 5.0) test_unary_op[ - OP.UNSQUEEZE, t1_shape, AttributeVector(Attribute("dims", TensorShape(1, 3))) + OP.UNSQUEEZE, + t1_shape, + AttributeVector(Attribute("dims", TensorShape(1, 3))), ](t1, expected) expected = Tensor[dtype](2, 1, 3) fill(expected, 5.0) test_unary_op[ - OP.UNSQUEEZE, t1_shape, AttributeVector(Attribute("dims", TensorShape(1))) + OP.UNSQUEEZE, + t1_shape, + AttributeVector(Attribute("dims", TensorShape(1))), ](t1, expected) expected = Tensor[dtype](1, 2, 3) fill(expected, 5.0) test_unary_op[ - OP.UNSQUEEZE, t1_shape, AttributeVector(Attribute("dims", TensorShape(-3))) + OP.UNSQUEEZE, + t1_shape, + AttributeVector(Attribute("dims", TensorShape(-3))), ](t1, expected) expected = Tensor[dtype](2, 1, 3, 1) fill(expected, 5.0) test_unary_op[ - OP.UNSQUEEZE, t1_shape, AttributeVector(Attribute("dims", TensorShape(-1, -3))) + OP.UNSQUEEZE, + t1_shape, + AttributeVector(Attribute("dims", TensorShape(-1, -3))), ](t1, expected) @@ -262,7 +346,9 @@ fn test_backward_UNSQUEEZE() raises: var expected_grad = Tensor[dtype](2, 3) fill(expected_grad, 5.0) - test_unary_op_backward[OP.UNSQUEEZE, t1_shape, ug_shape](t1, ug, expected_grad) + test_unary_op_backward[OP.UNSQUEEZE, t1_shape, ug_shape]( + t1, ug, expected_grad + ) fn test_SLICE() raises: @@ -270,7 +356,7 @@ fn test_SLICE() raises: var t1: Tensor[dtype] = Tensor[dtype](t1_shape) for i in range(t1.num_elements()): t1[i] = i - + alias slice = Slice(1, 3, 1) # dim = 0 @@ -278,15 +364,17 @@ fn test_SLICE() raises: for i in range(2): for j in range(4): for k in range(5): - expected_0[i*4*5 + j*5 + k] = (i + 1) * 4 * 5 + j * 5 + k + expected_0[i * 4 * 5 + j * 5 + k] = (i + 1) * 4 * 5 + j * 5 + k test_unary_op[ - OP.SLICE, t1_shape, AttributeVector( + OP.SLICE, + t1_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(0)) - ) + Attribute("axes", TensorShape(0)), + ), ](t1, expected_0) # dim = 1 @@ -294,15 +382,17 @@ fn test_SLICE() raises: for i in range(3): for j in range(2): for k in range(5): - expected_1[i*2*5 + j*5 + k] = i * 4 * 5 + (j + 1) * 5 + k + expected_1[i * 2 * 5 + j * 5 + k] = i * 4 * 5 + (j + 1) * 5 + k test_unary_op[ - OP.SLICE, t1_shape, AttributeVector( + OP.SLICE, + t1_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(1)) - ) + Attribute("axes", TensorShape(1)), + ), ](t1, expected_1) # dim = 2 @@ -310,15 +400,17 @@ fn test_SLICE() raises: for i in range(3): for j in range(4): for k in range(2): - expected_2[i*4*2 + j*2 + k] = i * 4 * 5 + j * 5 + (k + 1) - + expected_2[i * 4 * 2 + j * 2 + k] = i * 4 * 5 + j * 5 + (k + 1) + test_unary_op[ - OP.SLICE, t1_shape, AttributeVector( + OP.SLICE, + t1_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(2)) - ) + Attribute("axes", TensorShape(2)), + ), ](t1, expected_2) @@ -335,15 +427,19 @@ fn test_SLICE_step() raises: for i in range(3): for j in range(2): for k in range(2): - expected_0[i*2*2 + j*2 + k] = (i*2 + 1) * 2 * 2 + j * 2 + k + expected_0[i * 2 * 2 + j * 2 + k] = ( + (i * 2 + 1) * 2 * 2 + j * 2 + k + ) test_unary_op[ - OP.SLICE, t0_shape, AttributeVector( + OP.SLICE, + t0_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(0)) - ) + Attribute("axes", TensorShape(0)), + ), ](t0, expected_0) # dim = 1 @@ -356,15 +452,19 @@ fn test_SLICE_step() raises: for i in range(2): for j in range(3): for k in range(2): - expected_1[i*3*2 + j*2 + k] = i * 10 * 2 + (j*2 + 1) * 2 + k + expected_1[i * 3 * 2 + j * 2 + k] = ( + i * 10 * 2 + (j * 2 + 1) * 2 + k + ) test_unary_op[ - OP.SLICE, t1_shape, AttributeVector( + OP.SLICE, + t1_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(1)) - ) + Attribute("axes", TensorShape(1)), + ), ](t1, expected_1) # dim = 2 @@ -377,15 +477,19 @@ fn test_SLICE_step() raises: for i in range(2): for j in range(2): for k in range(3): - expected_2[i*2*3 + j*3 + k] = i * 2 * 10 + j * 10 + (k*2 + 1) + expected_2[i * 2 * 3 + j * 3 + k] = ( + i * 2 * 10 + j * 10 + (k * 2 + 1) + ) test_unary_op[ - OP.SLICE, t2_shape, AttributeVector( + OP.SLICE, + t2_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(2)) - ) + Attribute("axes", TensorShape(2)), + ), ](t2, expected_2) @@ -402,15 +506,19 @@ fn test_SLICE_neg() raises: for i in range(3): for j in range(2): for k in range(2): - expected_0[i*2*2 + j*2 + k] = StaticIntTuple[3](6, 4, 2)[i] * 2 * 2 + j * 2 + k + expected_0[i * 2 * 2 + j * 2 + k] = ( + StaticIntTuple[3](6, 4, 2)[i] * 2 * 2 + j * 2 + k + ) test_unary_op[ - OP.SLICE, t0_shape, AttributeVector( + OP.SLICE, + t0_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(0)) - ) + Attribute("axes", TensorShape(0)), + ), ](t0, expected_0) # dim = 1 @@ -423,15 +531,19 @@ fn test_SLICE_neg() raises: for i in range(2): for j in range(3): for k in range(2): - expected_1[i*3*2 + j*2 + k] = i * 10 * 2 + StaticIntTuple[3](6, 4, 2)[j] * 2 + k + expected_1[i * 3 * 2 + j * 2 + k] = ( + i * 10 * 2 + StaticIntTuple[3](6, 4, 2)[j] * 2 + k + ) test_unary_op[ - OP.SLICE, t1_shape, AttributeVector( + OP.SLICE, + t1_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(1)) - ) + Attribute("axes", TensorShape(1)), + ), ](t1, expected_1) # dim = 2 @@ -444,15 +556,19 @@ fn test_SLICE_neg() raises: for i in range(2): for j in range(2): for k in range(3): - expected_2[i*2*3 + j*3 + k] = i * 2 * 10 + j * 10 + StaticIntTuple[3](6, 4, 2)[k] + expected_2[i * 2 * 3 + j * 3 + k] = ( + i * 2 * 10 + j * 10 + StaticIntTuple[3](6, 4, 2)[k] + ) test_unary_op[ - OP.SLICE, t2_shape, AttributeVector( + OP.SLICE, + t2_shape, + AttributeVector( Attribute("starts", TensorShape(slice.start)), Attribute("ends", TensorShape(slice.end)), Attribute("steps", TensorShape(slice.step)), - Attribute("axes", TensorShape(2)) - ) + Attribute("axes", TensorShape(2)), + ), ](t2, expected_2) @@ -470,22 +586,35 @@ fn test_SLICE_multiple_axes() raises: for i in range(3): for j in range(3): for k in range(5): - expected[i*3*5 + j*5 + k] = StaticIntTuple[5](1, 3, 5, 7, 9)[i] * 32 * 40 + StaticIntTuple[3](3, 6, 9)[j] * 40 + StaticIntTuple[5](5, 7, 9, 11, 13)[k] - + expected[i * 3 * 5 + j * 5 + k] = ( + StaticIntTuple[5](1, 3, 5, 7, 9)[i] * 32 * 40 + + StaticIntTuple[3](3, 6, 9)[j] * 40 + + StaticIntTuple[5](5, 7, 9, 11, 13)[k] + ) + test_unary_op[ - OP.SLICE, t1_shape, AttributeVector( - Attribute("starts", TensorShape(slice_0.start, slice_1.start, slice_2.start)), - Attribute("ends", TensorShape(slice_0.end, slice_1.end, slice_2.end)), - Attribute("steps", TensorShape(slice_0.step, slice_1.step, slice_2.step)), + OP.SLICE, + t1_shape, + AttributeVector( + Attribute( + "starts", + TensorShape(slice_0.start, slice_1.start, slice_2.start), + ), + Attribute( + "ends", TensorShape(slice_0.end, slice_1.end, slice_2.end) + ), + Attribute( + "steps", TensorShape(slice_0.step, slice_1.step, slice_2.step) + ), # Attribute("axes", TensorShape(0, 1, 2)) - ) + ), ](t1, expected) alias t2_shape = TensorShape(20, 32, 40, 50) var t2: Tensor[dtype] = Tensor[dtype](t2_shape) for i in range(t2.num_elements()): t2[i] = i - + alias slice_2_1 = Slice(1, 6, 2) alias slice_2_2 = Slice(3, 10, 3) alias slice_2_3 = Slice(5, 15, 2) @@ -497,14 +626,42 @@ fn test_SLICE_multiple_axes() raises: for j in range(3): for k in range(5): for l in range(4): - expected_2[i*3*5*4 + j*5*4 + k*4 + l] = StaticIntTuple[5](1, 3, 5, 7, 9)[i] * 32 * 40 * 50 + StaticIntTuple[3](3, 6, 9)[j] * 40 * 50 + StaticIntTuple[5](5, 7, 9, 11, 13)[k] * 50 + StaticIntTuple[4](7, 11, 15, 19)[l] - + expected_2[i * 3 * 5 * 4 + j * 5 * 4 + k * 4 + l] = ( + StaticIntTuple[5](1, 3, 5, 7, 9)[i] * 32 * 40 * 50 + + StaticIntTuple[3](3, 6, 9)[j] * 40 * 50 + + StaticIntTuple[5](5, 7, 9, 11, 13)[k] * 50 + + StaticIntTuple[4](7, 11, 15, 19)[l] + ) + test_unary_op[ - OP.SLICE, t2_shape, AttributeVector( - Attribute("starts", TensorShape(slice_2_1.start, slice_2_2.start, slice_2_3.start, slice_2_4.start)), - Attribute("ends", TensorShape(slice_2_1.end, slice_2_2.end, slice_2_3.end, slice_2_4.end)), - Attribute("steps", TensorShape(slice_2_1.step, slice_2_2.step, slice_2_3.step, slice_2_4.step)), - ) + OP.SLICE, + t2_shape, + AttributeVector( + Attribute( + "starts", + TensorShape( + slice_2_1.start, + slice_2_2.start, + slice_2_3.start, + slice_2_4.start, + ), + ), + Attribute( + "ends", + TensorShape( + slice_2_1.end, slice_2_2.end, slice_2_3.end, slice_2_4.end + ), + ), + Attribute( + "steps", + TensorShape( + slice_2_1.step, + slice_2_2.step, + slice_2_3.step, + slice_2_4.step, + ), + ), + ), ](t2, expected_2) @@ -523,15 +680,18 @@ fn test_backward_SLICE() raises: for i in range(2): for j in range(4): for k in range(5): - expected_ug0[(i+1)*4*5 + j*5 + k] = 1.0 + expected_ug0[(i + 1) * 4 * 5 + j * 5 + k] = 1.0 test_unary_op_backward[ - OP.SLICE, t0_shape, ug0_shape, AttributeVector( + OP.SLICE, + t0_shape, + ug0_shape, + AttributeVector( Attribute("starts", TensorShape(slice_0.start)), Attribute("ends", TensorShape(slice_0.end)), Attribute("steps", TensorShape(slice_0.step)), - Attribute("axes", TensorShape(0)) - ) + Attribute("axes", TensorShape(0)), + ), ](t0, ug0, expected_ug0) # dim = 1 (step = 2) @@ -543,20 +703,23 @@ fn test_backward_SLICE() raises: alias ug1_shape = TensorShape(2, 3, 2) var ug1: Tensor[dtype] = Tensor[dtype](ug1_shape) fill(ug1, 1.0) - + var expected_ug1 = Tensor[dtype](t1_shape) for i in range(2): for j in range(3): for k in range(2): - expected_ug1[i*10*2 + (j*2 + 1)*2 + k] = 1.0 + expected_ug1[i * 10 * 2 + (j * 2 + 1) * 2 + k] = 1.0 test_unary_op_backward[ - OP.SLICE, t1_shape, ug1_shape, AttributeVector( + OP.SLICE, + t1_shape, + ug1_shape, + AttributeVector( Attribute("starts", TensorShape(slice_1.start)), Attribute("ends", TensorShape(slice_1.end)), Attribute("steps", TensorShape(slice_1.step)), - Attribute("axes", TensorShape(1)) - ) + Attribute("axes", TensorShape(1)), + ), ](t1, ug1, expected_ug1) # dim = 2 (step = -2) @@ -573,15 +736,20 @@ fn test_backward_SLICE() raises: for i in range(2): for j in range(2): for k in range(3): - expected_ug2[i*2*10 + j*10 + StaticIntTuple[3](6, 4, 2)[k]] = 1.0 + expected_ug2[ + i * 2 * 10 + j * 10 + StaticIntTuple[3](6, 4, 2)[k] + ] = 1.0 test_unary_op_backward[ - OP.SLICE, t2_shape, ug2_shape, AttributeVector( + OP.SLICE, + t2_shape, + ug2_shape, + AttributeVector( Attribute("starts", TensorShape(slice_2.start)), Attribute("ends", TensorShape(slice_2.end)), Attribute("steps", TensorShape(slice_2.step)), - Attribute("axes", TensorShape(2)) - ) + Attribute("axes", TensorShape(2)), + ), ](t2, ug2, expected_ug2) @@ -599,8 +767,12 @@ fn test_backward_SLICE_multiple_axes() raises: for i in range(3): for j in range(3): for k in range(5): - expected[i*3*5 + j*5 + k] = StaticIntTuple[5](1, 3, 5, 7, 9)[i] * 32 * 40 + StaticIntTuple[3](3, 6, 9)[j] * 40 + StaticIntTuple[5](5, 7, 9, 11, 13)[k] - + expected[i * 3 * 5 + j * 5 + k] = ( + StaticIntTuple[5](1, 3, 5, 7, 9)[i] * 32 * 40 + + StaticIntTuple[3](3, 6, 9)[j] * 40 + + StaticIntTuple[5](5, 7, 9, 11, 13)[k] + ) + alias ug_shape = TensorShape(3, 3, 5) var ug: Tensor[dtype] = Tensor[dtype](ug_shape) fill(ug, 1.0) @@ -609,17 +781,127 @@ fn test_backward_SLICE_multiple_axes() raises: for i in range(3): for j in range(3): for k in range(5): - expected_ug[StaticIntTuple[5](1, 3, 5, 7, 9)[i] * 32 * 40 + StaticIntTuple[3](3, 6, 9)[j] * 40 + StaticIntTuple[5](5, 7, 9, 11, 13)[k]] = 1.0 + expected_ug[ + StaticIntTuple[5](1, 3, 5, 7, 9)[i] * 32 * 40 + + StaticIntTuple[3](3, 6, 9)[j] * 40 + + StaticIntTuple[5](5, 7, 9, 11, 13)[k] + ] = 1.0 test_unary_op_backward[ - OP.SLICE, t1_shape, ug_shape, AttributeVector( - Attribute("starts", TensorShape(slice_0.start, slice_1.start, slice_2.start)), - Attribute("ends", TensorShape(slice_0.end, slice_1.end, slice_2.end)), - Attribute("steps", TensorShape(slice_0.step, slice_1.step, slice_2.step)), - ) + OP.SLICE, + t1_shape, + ug_shape, + AttributeVector( + Attribute( + "starts", + TensorShape(slice_0.start, slice_1.start, slice_2.start), + ), + Attribute( + "ends", TensorShape(slice_0.end, slice_1.end, slice_2.end) + ), + Attribute( + "steps", TensorShape(slice_0.step, slice_1.step, slice_2.step) + ), + ), ](t1, ug, expected_ug) +from basalt.autograd.ops.mlops import INDEX + +fn test_INDEX() raises: + alias t1_shape = TensorShape(2, 3, 5) + var t = Tensor[dtype](t1_shape) + for i in range(t.num_elements()): + t[i] = i + + # t[:, [0, 0], 0:5:2] + # TODO: need for a list attribute as this only supports to specify indeces of MAX_RANK + alias attr_1 = Attribute("dim_1i", TensorShape(0, 0)) + alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2)) + + var expected = Tensor[dtype](2, 2, 3) + for i in range(2): + for j in range(2): + for k in range(3): + expected[i*2*3 + j*3 + k] = i * 3 * 5 + k * 2 + + test_unary_op[ + OP.INDEX, t1_shape, AttributeVector( + attr_1, + attr_2, + ) + ](t, expected) + + +fn test_INDEX_backward() raises: + alias t1_shape = TensorShape(2, 3, 5) + var t = Tensor[dtype](t1_shape) + for i in range(t.num_elements()): + t[i] = i + + alias attr_1 = Attribute("dim_1i", TensorShape(0, 0)) + alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2)) + + alias ug_shape = TensorShape(2, 2, 3) + var ug = Tensor[dtype](ug_shape) + fill(ug, 1.0) + + var expected = Tensor[dtype](t1_shape) + for i in range(2): + for j in range(2): + for k in range(3): + # NOTE: `+=` because selected indeces [0, 0] can repeat + expected[i * 3 * 5 + k * 2] += 1.0 + + test_unary_op_backward[ + OP.INDEX, t1_shape, ug_shape, AttributeVector( + attr_1, + attr_2, + ) + ](t, ug, expected) + +fn test_UPSAMPLE() raises: + alias t1_shape = TensorShape(2, 3, 5) + var t = Tensor[dtype](t1_shape) + for i in range(t.num_elements()): + t[i] = i + + var expected = Tensor[dtype](2, 3, 10) + for i in range(2): + for j in range(3): + for k in range(5): + for l in range(2): + expected[i*3*10 + j*10 + k*2 + l] = t[i*3*5 + j*5 + k] + + test_unary_op[ + OP.UPSAMPLE, t1_shape, AttributeVector( + Attribute("scales", TensorShape(2)), + Attribute("mode", "nearest") + ) + ](t, expected) + + + alias t2_shape = TensorShape(1, 1, 2, 2) + t = Tensor[dtype](t2_shape) + for i in range(t.num_elements()): + t[i] = i + + expected = Tensor[dtype](1, 1, 4, 6) + for i in range(1): + for j in range(1): + for k in range(4): + for l in range(6): + var pos = i*1*2*2 + j*2*2 + (k // 2) * 2 + (l // 3) + expected[i*1*4*6 + j*4*6 + k*6 + l] = t[pos] + + test_unary_op[ + OP.UPSAMPLE, t2_shape, AttributeVector( + Attribute("scales", TensorShape(2, 3)), + Attribute("mode", "nearest") + ) + ](t, expected) + + fn main(): try: test_SIGMOID() @@ -632,6 +914,8 @@ fn main(): test_SLICE_step() test_SLICE_neg() test_SLICE_multiple_axes() + test_INDEX() + test_UPSAMPLE() except e: print("[ERROR] Error in forward mlops") print(e) @@ -646,6 +930,8 @@ fn main(): test_backward_UNSQUEEZE() test_backward_SLICE() test_backward_SLICE_multiple_axes() + test_INDEX_backward() + pass except e: print("[ERROR] Error in backward mlops") print(e) diff --git a/tests/mojo/test_tensorutils.mojo b/tests/mojo/test_tensorutils.mojo index 1612f050..7ea9cc0a 100644 --- a/tests/mojo/test_tensorutils.mojo +++ b/tests/mojo/test_tensorutils.mojo @@ -1,11 +1,11 @@ from random import rand from testing import assert_equal, assert_almost_equal -from math import sqrt, exp, round, add, sub, mul, div +from math import sqrt, exp from basalt import dtype, nelts +from basalt.autograd.ops.matmul import dot from basalt.utils.tensorutils import ( fill, - dot, elwise_transform, elwise_pow, elwise_op, @@ -20,6 +20,7 @@ from basalt.utils.tensorutils import ( transpose, ) from basalt.nn import Tensor, TensorShape +from basalt.utils.math_util import add, sub, mul, div, round_simd from tests import assert_tensors_equal @@ -81,7 +82,7 @@ fn test_elwise_transform() raises: assert_tensors_equal(B_res, C) var C_res = Tensor[dtype](2, 10) - elwise_transform[round](C_res, C) + elwise_transform[round_simd](C_res, C) assert_tensors_equal(C_res, D) diff --git a/tests/mojo/test_tensorutils_data.mojo b/tests/mojo/test_tensorutils_data.mojo index 4cf956e9..3a7466fe 100644 --- a/tests/mojo/test_tensorutils_data.mojo +++ b/tests/mojo/test_tensorutils_data.mojo @@ -1,8 +1,7 @@ -from math import add - from basalt import dtype, nelts from basalt.nn import Tensor, TensorShape from basalt.utils.tensorutils import fill, elwise_op +from basalt.utils.math_util import add fn generate_tensor(*shape: Int) -> Tensor[dtype]: diff --git a/tests/python/test_mlops_torch.mojo b/tests/python/test_mlops_torch.mojo index 2f4747cb..ba6288f8 100644 --- a/tests/python/test_mlops_torch.mojo +++ b/tests/python/test_mlops_torch.mojo @@ -1,5 +1,5 @@ from random import rand -from math.limit import min_finite, max_finite +from utils.numerics import min_finite, max_finite from collections.optional import OptionalReg, Optional from python.python import Python from python.object import PythonObject @@ -47,6 +47,11 @@ fn torch_unary_op( expected = torch.sigmoid(input_1) elif op == OP.RELU: expected = torch.relu(input_1) + elif op == OP.LEAKYRELU: + expected = torch.nn.functional.leaky_relu( + input_1, + attrs.value()["negative_slope"].value().to_scalar[dtype](), + ) elif op == OP.TANH: expected = torch.tanh(input_1) elif op == OP.CLIP: @@ -65,7 +70,9 @@ fn torch_unary_op( var dim = attrs["dims"] if dim: - expected = torch.squeeze(input_1, dim=dim.value().to_shape()[0]) + expected = torch.squeeze( + input_1, dim=dim.value().to_shape()[0] + ) else: expected = torch.squeeze(input_1) elif attrs_tuple: @@ -78,7 +85,9 @@ fn torch_unary_op( var dim = attrs["dims"] if dim: - expected = torch.unsqueeze(input_1, dim=dim.value().to_shape()[0]) + expected = torch.unsqueeze( + input_1, dim=dim.value().to_shape()[0] + ) else: expected = torch.unsqueeze(input_1, 0) elif attrs_tuple: @@ -102,12 +111,24 @@ fn torch_unary_op( if step < 0: flip_dims.append(dim) - step = step *- 1 + step = step * -1 end, start = (end + 1) * -1, (start + 1) * -1 indices[dim] = py.slice(start, end, step) - + expected = input_1.flip(flip_dims)[indices] + elif op == OP.UPSAMPLE: + var attrs = attrs.value() + var scales = attrs["scales"].value().to_shape() + var mode = attrs["mode"].value().to_string() + + var scales_py = PythonObject([]) + for i in range(scales.rank()): + scales_py.append(scales[i]) + + expected = torch.nn.functional.interpolate( + input_1, scale_factor=scales_py, mode=mode + ) else: print("Error: op not supported (returning the value input_1): ", op) expected = input_1 @@ -159,6 +180,31 @@ fn test_RELU() raises: ) +fn test_LEAKYRELU() raises: + alias t1_shape = TensorShape(37, 63, 107) + alias ug_shape = TensorShape(37, 63, 107) + var t1: Tensor[dtype] = Tensor[dtype](t1_shape) + rand(t1.data(), t1.num_elements()) + + var ug = Tensor[dtype](ug_shape) + rand(ug.data(), ug.num_elements()) + + var expected_and_grad = torch_unary_op( + OP.LEAKYRELU, t1, ug, AttributeVector(Attribute("negative_slope", Float32(0.1))) + ) + test_unary_op[ + OP.LEAKYRELU, + t1_shape, + AttributeVector(Attribute("negative_slope", 0.1)), + ](t1, expected_and_grad.expected) + test_unary_op_backward[ + OP.LEAKYRELU, + t1_shape, + ug_shape, + AttributeVector(Attribute("negative_slope", 0.1)), + ](t1, ug, expected_and_grad.grad_1) + + fn test_TANH() raises: alias t1_shape = TensorShape(37, 63, 107) alias ug_shape = TensorShape(37, 63, 107) @@ -193,23 +239,27 @@ fn test_CLIP() raises: # Clip with min alias min_attr = Attribute("min", 0.3333) - expected_and_grad = torch_unary_op(OP.CLIP, t1, ug, AttributeVector(min_attr)) + expected_and_grad = torch_unary_op( + OP.CLIP, t1, ug, AttributeVector(min_attr) + ) test_unary_op[OP.CLIP, t1_shape, AttributeVector(min_attr)]( t1, expected_and_grad.expected ) - test_unary_op_backward[OP.CLIP, t1_shape, ug_shape, AttributeVector(min_attr)]( - t1, ug, expected_and_grad.grad_1 - ) + test_unary_op_backward[ + OP.CLIP, t1_shape, ug_shape, AttributeVector(min_attr) + ](t1, ug, expected_and_grad.grad_1) # Clip with max alias max_attr = Attribute("max", 0.6666) - expected_and_grad = torch_unary_op(OP.CLIP, t1, ug, AttributeVector(max_attr)) + expected_and_grad = torch_unary_op( + OP.CLIP, t1, ug, AttributeVector(max_attr) + ) test_unary_op[OP.CLIP, t1_shape, AttributeVector(max_attr)]( t1, expected_and_grad.expected ) - test_unary_op_backward[OP.CLIP, t1_shape, ug_shape, AttributeVector(max_attr)]( - t1, ug, expected_and_grad.grad_1 - ) + test_unary_op_backward[ + OP.CLIP, t1_shape, ug_shape, AttributeVector(max_attr) + ](t1, ug, expected_and_grad.grad_1) # Clip with min and max expected_and_grad = torch_unary_op( @@ -249,9 +299,9 @@ fn test_SQUEEZE() raises: test_unary_op[OP.SQUEEZE, t1_shape, AttributeVector(dim)]( t1, expected_and_grad.expected ) - test_unary_op_backward[OP.SQUEEZE, t1_shape, ug_shape_1, AttributeVector(dim)]( - t1, ug, expected_and_grad.grad_1 - ) + test_unary_op_backward[ + OP.SQUEEZE, t1_shape, ug_shape_1, AttributeVector(dim) + ](t1, ug, expected_and_grad.grad_1) alias ug_shape_2 = TensorShape(20, 28, 1) ug = Tensor[dtype](ug_shape_2) @@ -259,13 +309,15 @@ fn test_SQUEEZE() raises: alias dim_2 = Attribute("dims", TensorShape(1)) - expected_and_grad = torch_unary_op(OP.SQUEEZE, t1, ug, AttributeVector(dim_2)) + expected_and_grad = torch_unary_op( + OP.SQUEEZE, t1, ug, AttributeVector(dim_2) + ) test_unary_op[OP.SQUEEZE, t1_shape, AttributeVector(dim_2)]( t1, expected_and_grad.expected ) - test_unary_op_backward[OP.SQUEEZE, t1_shape, ug_shape_2, AttributeVector(dim_2)]( - t1, ug, expected_and_grad.grad_1 - ) + test_unary_op_backward[ + OP.SQUEEZE, t1_shape, ug_shape_2, AttributeVector(dim_2) + ](t1, ug, expected_and_grad.grad_1) # Squeeze with multiple dims ug = Tensor[dtype](ug_shape) @@ -282,9 +334,9 @@ fn test_SQUEEZE() raises: test_unary_op[OP.SQUEEZE, t1_shape, AttributeVector(dims)]( t1, expected_and_grad.expected ) - test_unary_op_backward[OP.SQUEEZE, t1_shape, ug_shape, AttributeVector(dims)]( - t1, ug, expected_and_grad.grad_1 - ) + test_unary_op_backward[ + OP.SQUEEZE, t1_shape, ug_shape, AttributeVector(dims) + ](t1, ug, expected_and_grad.grad_1) fn test_UNSQUEEZE() raises: @@ -298,13 +350,15 @@ fn test_UNSQUEEZE() raises: alias dim = Attribute("dims", TensorShape(1)) - var expected_and_grad = torch_unary_op(OP.UNSQUEEZE, t1, ug, AttributeVector(dim)) + var expected_and_grad = torch_unary_op( + OP.UNSQUEEZE, t1, ug, AttributeVector(dim) + ) test_unary_op[OP.UNSQUEEZE, t1_shape, AttributeVector(dim)]( t1, expected_and_grad.expected ) - test_unary_op_backward[OP.UNSQUEEZE, t1_shape, ug_shape, AttributeVector(dim)]( - t1, ug, expected_and_grad.grad_1 - ) + test_unary_op_backward[ + OP.UNSQUEEZE, t1_shape, ug_shape, AttributeVector(dim) + ](t1, ug, expected_and_grad.grad_1) # Unsqueeze with multiple dims alias ug_shape_2 = TensorShape(20, 1, 28, 1) @@ -321,9 +375,9 @@ fn test_UNSQUEEZE() raises: test_unary_op[OP.UNSQUEEZE, t1_shape, AttributeVector(dims)]( t1, expected_and_grad.expected ) - test_unary_op_backward[OP.UNSQUEEZE, t1_shape, ug_shape_2, AttributeVector(dims)]( - t1, ug, expected_and_grad.grad_1 - ) + test_unary_op_backward[ + OP.UNSQUEEZE, t1_shape, ug_shape_2, AttributeVector(dims) + ](t1, ug, expected_and_grad.grad_1) fn test_SLICE() raises: @@ -337,17 +391,23 @@ fn test_SLICE() raises: Attribute("starts", TensorShape(slice_0.start)), Attribute("ends", TensorShape(slice_0.end)), Attribute("steps", TensorShape(slice_0.step)), - Attribute("axes", TensorShape(0)) + Attribute("axes", TensorShape(0)), ) alias ug_shape = TensorShape(65, 322, 317) var ug = Tensor[dtype](ug_shape) rand(ug.data(), ug.num_elements()) - var attrs_tuple_0 = PythonObject((slice_0.start, slice_0.end, slice_0.step, 0)) - var expected_and_grad = torch_unary_op(OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_0) + var attrs_tuple_0 = PythonObject( + (slice_0.start, slice_0.end, slice_0.step, 0) + ) + var expected_and_grad = torch_unary_op( + OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_0 + ) test_unary_op[OP.SLICE, t1_shape, attrs_0](t1, expected_and_grad.expected) - test_unary_op_backward[OP.SLICE, t1_shape, ug_shape, attrs_0](t1, ug, expected_and_grad.grad_1) + test_unary_op_backward[OP.SLICE, t1_shape, ug_shape, attrs_0]( + t1, ug, expected_and_grad.grad_1 + ) # dim = 1 alias slice_1 = Slice(10, 311, 5) @@ -355,17 +415,23 @@ fn test_SLICE() raises: Attribute("starts", TensorShape(slice_1.start)), Attribute("ends", TensorShape(slice_1.end)), Attribute("steps", TensorShape(slice_1.step)), - Attribute("axes", TensorShape(1)) + Attribute("axes", TensorShape(1)), ) alias ug_shape_1 = TensorShape(430, 61, 317) ug = Tensor[dtype](ug_shape_1) rand(ug.data(), ug.num_elements()) - var attrs_tuple_1 = PythonObject((slice_1.start, slice_1.end, slice_1.step, 1)) - expected_and_grad = torch_unary_op(OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_1) + var attrs_tuple_1 = PythonObject( + (slice_1.start, slice_1.end, slice_1.step, 1) + ) + expected_and_grad = torch_unary_op( + OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_1 + ) test_unary_op[OP.SLICE, t1_shape, attrs_1](t1, expected_and_grad.expected) - test_unary_op_backward[OP.SLICE, t1_shape, ug_shape_1, attrs_1](t1, ug, expected_and_grad.grad_1) + test_unary_op_backward[OP.SLICE, t1_shape, ug_shape_1, attrs_1]( + t1, ug, expected_and_grad.grad_1 + ) # dim = 2 alias slice_2 = Slice(293, 33, -7) @@ -373,20 +439,26 @@ fn test_SLICE() raises: Attribute("starts", TensorShape(slice_2.start)), Attribute("ends", TensorShape(slice_2.end)), Attribute("steps", TensorShape(slice_2.step)), - Attribute("axes", TensorShape(2)) + Attribute("axes", TensorShape(2)), ) alias ug_shape_2 = TensorShape(430, 322, 38) ug = Tensor[dtype](ug_shape_2) rand(ug.data(), ug.num_elements()) - var attrs_tuple_2 = PythonObject((slice_2.start, slice_2.end, slice_2.step, 2)) - expected_and_grad = torch_unary_op(OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_2) + var attrs_tuple_2 = PythonObject( + (slice_2.start, slice_2.end, slice_2.step, 2) + ) + expected_and_grad = torch_unary_op( + OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_2 + ) test_unary_op[OP.SLICE, t1_shape, attrs_2](t1, expected_and_grad.expected) - test_unary_op_backward[OP.SLICE, t1_shape, ug_shape_2, attrs_2](t1, ug, expected_and_grad.grad_1) + test_unary_op_backward[OP.SLICE, t1_shape, ug_shape_2, attrs_2]( + t1, ug, expected_and_grad.grad_1 + ) # Multiple dims - + # dim = 0, 1 alias slice_0_1 = Slice(23, 340, 3) alias slice_1_1 = Slice(10, 250, 5) @@ -395,17 +467,32 @@ fn test_SLICE() raises: Attribute("starts", TensorShape(slice_0_1.start, slice_1_1.start)), Attribute("ends", TensorShape(slice_0_1.end, slice_1_1.end)), Attribute("steps", TensorShape(slice_0_1.step, slice_1_1.step)), - Attribute("axes", TensorShape(0, 1)) + Attribute("axes", TensorShape(0, 1)), ) alias ug_shape_0_1 = TensorShape(106, 48, 317) ug = Tensor[dtype](ug_shape_0_1) rand(ug.data(), ug.num_elements()) - var attrs_tuple_0_1 = PythonObject((slice_0_1.start, slice_0_1.end, slice_0_1.step, 0, slice_1_1.start, slice_1_1.end, slice_1_1.step, 1)) - expected_and_grad = torch_unary_op(OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_0_1) + var attrs_tuple_0_1 = PythonObject( + ( + slice_0_1.start, + slice_0_1.end, + slice_0_1.step, + 0, + slice_1_1.start, + slice_1_1.end, + slice_1_1.step, + 1, + ) + ) + expected_and_grad = torch_unary_op( + OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_0_1 + ) test_unary_op[OP.SLICE, t1_shape, attrs_0_1](t1, expected_and_grad.expected) - test_unary_op_backward[OP.SLICE, t1_shape, ug_shape_0_1, attrs_0_1](t1, ug, expected_and_grad.grad_1) + test_unary_op_backward[OP.SLICE, t1_shape, ug_shape_0_1, attrs_0_1]( + t1, ug, expected_and_grad.grad_1 + ) # dim = 0, 1, 2 alias slice_0_2 = Slice(-412, -5, 3) @@ -413,20 +500,112 @@ fn test_SLICE() raises: alias slice_2_2 = Slice(293, 33, -7) alias attrs_0_2 = AttributeVector( - Attribute("starts", TensorShape(slice_0_2.start, slice_1_2.start, slice_2_2.start)), - Attribute("ends", TensorShape(slice_0_2.end, slice_1_2.end, slice_2_2.end)), - Attribute("steps", TensorShape(slice_0_2.step, slice_1_2.step, slice_2_2.step)), - Attribute("axes", TensorShape(0, 1, 2)) + Attribute( + "starts", + TensorShape(slice_0_2.start, slice_1_2.start, slice_2_2.start), + ), + Attribute( + "ends", TensorShape(slice_0_2.end, slice_1_2.end, slice_2_2.end) + ), + Attribute( + "steps", TensorShape(slice_0_2.step, slice_1_2.step, slice_2_2.step) + ), + Attribute("axes", TensorShape(0, 1, 2)), ) alias ug_shape_0_2 = TensorShape(136, 35, 38) ug = Tensor[dtype](ug_shape_0_2) rand(ug.data(), ug.num_elements()) - var attrs_tuple_0_2 = PythonObject((slice_0_2.start, slice_0_2.end, slice_0_2.step, 0, slice_1_2.start, slice_1_2.end, slice_1_2.step, 1, slice_2_2.start, slice_2_2.end, slice_2_2.step, 2)) - expected_and_grad = torch_unary_op(OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_0_2) + var attrs_tuple_0_2 = PythonObject( + ( + slice_0_2.start, + slice_0_2.end, + slice_0_2.step, + 0, + slice_1_2.start, + slice_1_2.end, + slice_1_2.step, + 1, + slice_2_2.start, + slice_2_2.end, + slice_2_2.step, + 2, + ) + ) + expected_and_grad = torch_unary_op( + OP.SLICE, t1, ug, attrs_tuple=attrs_tuple_0_2 + ) test_unary_op[OP.SLICE, t1_shape, attrs_0_2](t1, expected_and_grad.expected) - test_unary_op_backward[OP.SLICE, t1_shape, ug_shape_0_2, attrs_0_2](t1, ug, expected_and_grad.grad_1) + test_unary_op_backward[OP.SLICE, t1_shape, ug_shape_0_2, attrs_0_2]( + t1, ug, expected_and_grad.grad_1 + ) + + +fn test_UPSAMPLE() raises: + alias t1_shape = TensorShape(41, 41, 43) + var t1 = Tensor[dtype](t1_shape) + rand(t1.data(), t1.num_elements()) + + alias attributes = AttributeVector( + Attribute("scales", TensorShape(3)), + Attribute("mode", "nearest") + ) + + alias ug_shape = TensorShape(41, 41, 129) + var ug = Tensor[dtype](ug_shape) + rand(ug.data(), ug.num_elements()) + + var expected_and_grad = torch_unary_op(OP.UPSAMPLE, t1, ug, attributes) + test_unary_op[OP.UPSAMPLE, t1_shape, attributes](t1, expected_and_grad.expected) + + alias attributes_2 = AttributeVector( + Attribute("scales", TensorShape(3)), + Attribute("mode", "linear") + ) + + expected_and_grad = torch_unary_op(OP.UPSAMPLE, t1, ug, attributes_2) + test_unary_op[OP.UPSAMPLE, t1_shape, attributes_2](t1, expected_and_grad.expected) + + alias t1_shape_1 = TensorShape(20, 20, 120, 120) + t1 = Tensor[dtype](t1_shape_1) + rand(t1.data(), t1.num_elements()) + + alias attributes_3 = AttributeVector( + Attribute("scales", TensorShape(2, 3)), + Attribute("mode", "nearest") + ) + + alias ug_shape_1 = TensorShape(20, 20, 240, 360) + ug = Tensor[dtype](ug_shape_1) + rand(ug.data(), ug.num_elements()) + + expected_and_grad = torch_unary_op(OP.UPSAMPLE, t1, ug, attributes_3) + test_unary_op[OP.UPSAMPLE, t1_shape_1, attributes_3](t1, expected_and_grad.expected) + + alias attributes_4 = AttributeVector( + Attribute("scales", TensorShape(2, 3)), + Attribute("mode", "bilinear") + ) + + expected_and_grad = torch_unary_op(OP.UPSAMPLE, t1, ug, attributes_4) + test_unary_op[OP.UPSAMPLE, t1_shape_1, attributes_4](t1, expected_and_grad.expected) + + alias t1_shape_2 = TensorShape(5, 5, 10, 20, 60) + t1 = Tensor[dtype](t1_shape_2) + rand(t1.data(), t1.num_elements()) + + alias attributes_5 = AttributeVector( + Attribute("scales", TensorShape(2, 3, 4)), + Attribute("mode", "nearest") + ) + + alias ug_shape_2 = TensorShape(5, 5, 20, 60, 240) + ug = Tensor[dtype](ug_shape_2) + rand(ug.data(), ug.num_elements()) + + expected_and_grad = torch_unary_op(OP.UPSAMPLE, t1, ug, attributes_5) + test_unary_op[OP.UPSAMPLE, t1_shape_2, attributes_5](t1, expected_and_grad.expected) fn main(): @@ -434,11 +613,13 @@ fn main(): try: test_SIGMOID() test_RELU() + test_LEAKYRELU() test_TANH() test_CLIP() test_SQUEEZE() test_UNSQUEEZE() test_SLICE() + test_UPSAMPLE() except e: print("[ERROR] Error in mlops (compare with torch)") print(e) diff --git a/tests/python/test_models_mnist.mojo b/tests/python/test_models_mnist.mojo index 85dd47df..5a0312d7 100644 --- a/tests/python/test_models_mnist.mojo +++ b/tests/python/test_models_mnist.mojo @@ -120,7 +120,7 @@ fn run_mojo[ ) var model = Model[graph]() - var optim = optim.Adam[graph](Reference(model.parameters), lr=learning_rate) + var optim = optim.Adam[graph](model.parameters, lr=learning_rate) var losses = List[Scalar[dtype]]() diff --git a/tests/python/test_models_regression.mojo b/tests/python/test_models_regression.mojo index cc884442..1a36b77b 100644 --- a/tests/python/test_models_regression.mojo +++ b/tests/python/test_models_regression.mojo @@ -1,6 +1,6 @@ from random import rand from python import Python -from math.limit import max_finite +from utils.numerics import max_finite from testing import assert_almost_equal from basalt import dtype diff --git a/tests/python/test_models_sin_estimate.mojo b/tests/python/test_models_sin_estimate.mojo index fe6e2a4d..9b592314 100644 --- a/tests/python/test_models_sin_estimate.mojo +++ b/tests/python/test_models_sin_estimate.mojo @@ -1,6 +1,6 @@ from random import rand from python import Python -from math.limit import max_finite +from utils.numerics import max_finite from testing import assert_almost_equal from basalt import dtype @@ -81,7 +81,7 @@ fn run_mojo[ ) var model = Model[graph]() - var optim = optim.Adam[graph](Reference(model.parameters), lr=learning_rate) + var optim = optim.Adam[graph](model.parameters, lr=learning_rate) var losses = List[Scalar[dtype]]() diff --git a/tests/python/test_upsample.mojo b/tests/python/test_upsample.mojo new file mode 100644 index 00000000..c5918ffe --- /dev/null +++ b/tests/python/test_upsample.mojo @@ -0,0 +1,159 @@ +from python.python import Python, PythonObject + +import basalt.nn as nn +from basalt import dtype, Graph +from basalt import Tensor, TensorShape +from tests import assert_tensors_equal, to_numpy, to_tensor + + +fn test_upsample[ + shape: TensorShape, + mode: StringLiteral, + scale_factors: List[Scalar[dtype]], + align_corners: Bool +]( + t1: Tensor[dtype], + ug: Tensor[dtype], + expected: Tensor[dtype], + expected_grad: Tensor[dtype] +) raises: + + fn create_graph() -> Graph: + var g = Graph() + var t1 = g.input(shape, trainable=True) + var t2 = nn.Upsample(g, t1, mode, scale_factors, align_corners) + g.out(t2) + return g ^ + + alias graph = create_graph() + var model = nn.Model[graph](inference_only=True) + var res = model.inference(t1)[0] + + model.backward(ug) + var res_grad = model.parameters.grads[graph.inputs[0]] + + assert_tensors_equal["almost"](res, expected) + assert_tensors_equal["almost"](res_grad, expected_grad) + + +@value +struct torch_upsample_result: + var expected: Tensor[dtype] + var grad: Tensor[dtype] + + +fn test_upsample_torch[ + shape: TensorShape, + mode: StringLiteral, + scale_factors: List[Scalar[dtype]], + align_corners: Bool +](data: PythonObject, ug: PythonObject) raises -> torch_upsample_result: + + var py = Python.import_module("builtins") + var np = Python.import_module("numpy") + var torch = Python.import_module("torch") + + var py_scales = py.list() + for i in range(len(scale_factors)): + py_scales.append(scale_factors[i]) + + # if mode == "nearest": + # var ups = torch.nn.Upsample(scale_factor=py.tuple(py_scales), mode=mode) + # else: + # var ups = torch.nn.Upsample(scale_factor=py.tuple(py_scales), mode=mode, align_corners=align_corners) + + var ups = torch.nn.Upsample(scale_factor=py.tuple(py_scales), mode=mode) + + var tensor = torch.from_numpy(data).requires_grad_(True) + var expected = ups(tensor) + var upper_grad = torch.from_numpy(ug) + _ = expected.backward(upper_grad) + + return torch_upsample_result( + to_tensor(expected.detach().numpy()), + to_tensor(tensor.grad.numpy()), + ) + + + +fn test_UPSAMPLE_nearest() raises: + var np = Python.import_module("numpy") + + alias shape = TensorShape(1, 1, 2, 2) + alias mode: StringLiteral = "nearest" + alias scales = List[Scalar[dtype]](2.0, 3.0) + alias align_corners = False + + var data = np.array([ + 1, 2, + 3, 4 + ], dtype=np.float32).reshape(1, 1, 2, 2) + + var ug = np.ones((1, 1, 4, 6)) + + var torch_out = test_upsample_torch[shape, mode, scales, align_corners](data, ug) + test_upsample[shape, mode, scales, align_corners]( + to_tensor(data), + to_tensor(ug), + torch_out.expected, + torch_out.grad + ) + + _ = data + + +fn test_UPSAMPLE_linear() raises: + var np = Python.import_module("numpy") + + alias shape = TensorShape(1, 1, 2, 2) + alias mode: StringLiteral = "linear" + alias scales = List[Scalar[dtype]](2.0, 2.0) + + var data = np.array([ + 1, 2, + 3, 4 + ], dtype=np.float32).reshape(1, 1, 2, 2) + + # var expected = np.array([ + # 1., 1.25, 1.75, 2. , + # 1.5, 1.75, 2.25, 2.5 , + # 2.5, 2.75, 3.25, 3.5 , + # 3., 3.25, 3.75, 4. , + # ], dtype=np.float32).reshape(1, 1, 4, 4) + + +fn test_UPSAMPLE_cubic() raises: + var np = Python.import_module("numpy") + + alias shape = TensorShape(1, 1, 4, 4) + alias mode: StringLiteral = "cubic" + alias scales = List[Scalar[dtype]](2.0, 2.0) + + var data = np.array([ + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + ], dtype=np.float32).reshape(1, 1, 4, 4) + + # var expected = np.array([ + # 0.47265625, 0.76953125, 1.24609375, 1.875, 2.28125, 2.91015625, 3.38671875, 3.68359375, + # 1.66015625, 1.95703125, 2.43359375, 3.0625, 3.46875, 4.09765625, 4.57421875, 4.87109375, + # 3.56640625, 3.86328125, 4.33984375, 4.96875, 5.375, 6.00390625, 6.48046875, 6.77734375, + # 6.08203125, 6.37890625, 6.85546875, 7.484375, 7.890625, 8.51953125, 8.99609375, 9.29296875, + # 7.70703125, 8.00390625, 8.48046875, 9.109375, 9.515625, 10.14453125, 10.62109375, 10.91796875, + # 10.22265625, 10.51953125, 10.99609375, 11.625, 12.03125, 12.66015625, 13.13671875, 13.43359375, + # 12.12890625, 12.42578125, 12.90234375, 13.53125, 13.9375, 14.56640625, 15.04296875, 15.33984375, + # 13.31640625, 13.61328125, 14.08984375, 14.71875, 15.125, 15.75390625, 16.23046875, 16.52734375 + # ], dtype=np.float32).reshape(1, 1, 8, 8) + + +fn main(): + + try: + test_UPSAMPLE_nearest() + # test_UPSAMPLE_linear() + # test_UPSAMPLE_cubic() + except e: + print("[Error] Error in Upsample") + print(e) \ No newline at end of file diff --git a/tests/testing_utils.mojo b/tests/testing_utils.mojo index e28d3114..3706f541 100644 --- a/tests/testing_utils.mojo +++ b/tests/testing_utils.mojo @@ -1,12 +1,14 @@ from python.python import Python from collections import OptionalReg from testing import assert_equal, assert_almost_equal +from algorithm import vectorize from basalt import dtype from basalt.autograd import Graph, OP from basalt.autograd.ops.ops import backward_op from basalt.autograd.attributes import AttributeVector from basalt.nn import Tensor, TensorShape, Model +from basalt.utils.tensor_creation_utils import to_numpy, to_tensor # The below regex should be used to convert deprecated calls @@ -19,13 +21,24 @@ fn assert_tensors_equal[ mode == "exact" or mode == "almost", "Mode must be either 'exact' or 'almost'" ]() + alias nelts = simdwidthof[dtype]() + assert_equal(t1.shape(), t2.shape(), "Tensor shape mismatch") - for i in range(t1.num_elements()): + @parameter + fn v_iter[nelts: Int](i: Int) raises: + @parameter if mode == "almost": - assert_almost_equal(t1[i], t2[i], rtol=1e-5, atol=1e-5, msg=msg) + assert_almost_equal(t1.load[nelts](i), t2.load[nelts](i), rtol=1e-5, atol=1e-5, msg=msg) else: - assert_equal(t1[i], t2[i], msg=msg) + assert_equal(t1.load[nelts](i), t2.load[nelts](i), msg=msg) + + for i in range(0, t1.num_elements() - nelts + 1, nelts): + v_iter[nelts](i) + + # Check the remaining elements + for i in range(nelts * (t1.num_elements() // nelts), t1.num_elements()): + v_iter[1](i) fn test_unary_op[ @@ -176,59 +189,6 @@ fn test_ternary_op_backward[ assert_tensors_equal["almost"](grad_3, grad_3_expected) -fn to_numpy(tensor: Tensor) -> PythonObject: - try: - var np = Python.import_module("numpy") - - np.set_printoptions(4) - - var rank = tensor.rank() - var dims = PythonObject([]) - for i in range(rank): - dims.append(tensor.dim(i)) - var pyarray: PythonObject = np.empty(dims, dtype=np.float32) - - var pointer = int(pyarray.__array_interface__['data'][0].to_float64()) - var pointer_d = DTypePointer[tensor.dtype](address=pointer) - memcpy(pointer_d, tensor.data(), tensor.num_elements()) - - _ = tensor - - return pyarray ^ - except e: - print("Error in to numpy", e) - return PythonObject() - - -fn to_tensor(np_array: PythonObject) raises -> Tensor[dtype]: - var shape = List[Int]() - for i in range(np_array.ndim): - shape.append(int(np_array.shape[i].to_float64())) - if np_array.ndim == 0: - # When the numpy array is a scalar, you need or the reshape to a size 1 ndarray or do this, if not the memcpy gets a memory error (Maybe because it is a register value?). - var tensor = Tensor[dtype](TensorShape(1)) - tensor[0] = np_array.to_float64().cast[dtype]() - return tensor ^ - - var tensor = Tensor[dtype](TensorShape(shape)) - - var np_array_2 = np_array.copy() - try: - var np = Python.import_module("numpy") - np_array_2 = np.float32(np_array_2) - except e: - print("Error in to tensor", e) - - var pointer = int(np_array_2.__array_interface__['data'][0].to_float64()) - var pointer_d = DTypePointer[tensor.dtype](address=pointer) - memcpy(tensor.data(), pointer_d, tensor.num_elements()) - - _ = np_array_2 - _ = np_array - - return tensor ^ - - fn create_graph_concat( t1_shape: TensorShape, t2_shape: TensorShape, t3_shape: TensorShape, dim: Int ) -> Graph: