diff --git a/code_to_optimize/go/algorithms.go b/code_to_optimize/go/algorithms.go new file mode 100644 index 000000000..47961e8ff --- /dev/null +++ b/code_to_optimize/go/algorithms.go @@ -0,0 +1,195 @@ +package sample + +import "strings" + +func TwoSum(nums []int, target int) [2]int { + for i := 0; i < len(nums); i++ { + for j := i + 1; j < len(nums); j++ { + if nums[i]+nums[j] == target { + return [2]int{i, j} + } + } + } + return [2]int{-1, -1} +} + +func FindDuplicates(nums []int) []int { + var result []int + for i := 0; i < len(nums); i++ { + found := false + for j := 0; j < i; j++ { + if nums[i] == nums[j] { + found = true + break + } + } + if found { + alreadyAdded := false + for _, r := range result { + if r == nums[i] { + alreadyAdded = true + break + } + } + if !alreadyAdded { + result = append(result, nums[i]) + } + } + } + return result +} + +func UniqueElements(nums []int) []int { + var result []int + for _, num := range nums { + found := false + for _, r := range result { + if r == num { + found = true + break + } + } + if !found { + result = append(result, num) + } + } + return result +} + +func MostFrequent(nums []int) int { + if len(nums) == 0 { + return 0 + } + + maxCount := 0 + maxNum := nums[0] + + for _, num := range nums { + count := 0 + for _, other := range nums { + if other == num { + count++ + } + } + if count > maxCount { + maxCount = count + maxNum = num + } + } + return maxNum +} + +func Intersection(a, b []int) []int { + var result []int + for _, x := range a { + for _, y := range b { + if x == y { + already := false + for _, r := range result { + if r == x { + already = true + break + } + } + if !already { + result = append(result, x) + } + } + } + } + return result +} + +func MergeSortedSlices(a, b []int) []int { + var result []int + result = append(result, a...) + result = append(result, b...) + + for i := 0; i < len(result); i++ { + for j := i + 1; j < len(result); j++ { + if result[j] < result[i] { + result[i], result[j] = result[j], result[i] + } + } + } + return result +} + +func LongestCommonPrefix(strs []string) string { + if len(strs) == 0 { + return "" + } + + prefix := strs[0] + for _, s := range strs[1:] { + for !strings.HasPrefix(s, prefix) { + prefix = prefix[:len(prefix)-1] + if prefix == "" { + return "" + } + } + } + return prefix +} + +func MaxSubarraySum(nums []int) int { + if len(nums) == 0 { + return 0 + } + + maxSum := nums[0] + for i := 0; i < len(nums); i++ { + for j := i; j < len(nums); j++ { + sum := 0 + for k := i; k <= j; k++ { + sum += nums[k] + } + if sum > maxSum { + maxSum = sum + } + } + } + return maxSum +} + +func IsPrime(n int) bool { + if n < 2 { + return false + } + for i := 2; i < n; i++ { + if n%i == 0 { + return false + } + } + return true +} + +func PrimesUpTo(limit int) []int { + var primes []int + for i := 2; i <= limit; i++ { + if IsPrime(i) { + primes = append(primes, i) + } + } + return primes +} + +func GCD(a, b int) int { + if a < 0 { + a = -a + } + if b < 0 { + b = -b + } + for b != 0 { + a, b = b, a%b + } + return a +} + +func LCM(a, b int) int { + if a == 0 || b == 0 { + return 0 + } + return a / GCD(a, b) * b +} diff --git a/code_to_optimize/go/algorithms_test.go b/code_to_optimize/go/algorithms_test.go new file mode 100644 index 000000000..a6ebc1485 --- /dev/null +++ b/code_to_optimize/go/algorithms_test.go @@ -0,0 +1,165 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestTwoSum(t *testing.T) { + got := TwoSum([]int{2, 7, 11, 15}, 9) + if got != [2]int{0, 1} { + t.Errorf("TwoSum([2,7,11,15], 9) = %v, want [0,1]", got) + } + + got = TwoSum([]int{1, 2, 3}, 10) + if got != [2]int{-1, -1} { + t.Errorf("TwoSum no match = %v, want [-1,-1]", got) + } +} + +func TestFindDuplicates(t *testing.T) { + got := FindDuplicates([]int{1, 2, 3, 2, 4, 3, 5}) + want := []int{2, 3} + if !reflect.DeepEqual(got, want) { + t.Errorf("FindDuplicates = %v, want %v", got, want) + } + + got = FindDuplicates([]int{1, 2, 3}) + if len(got) != 0 { + t.Errorf("expected no duplicates, got %v", got) + } +} + +func TestUniqueElements(t *testing.T) { + got := UniqueElements([]int{1, 2, 2, 3, 3, 3, 4}) + want := []int{1, 2, 3, 4} + if !reflect.DeepEqual(got, want) { + t.Errorf("UniqueElements = %v, want %v", got, want) + } +} + +func TestMostFrequent(t *testing.T) { + got := MostFrequent([]int{1, 2, 2, 3, 3, 3, 2, 2}) + if got != 2 { + t.Errorf("MostFrequent = %d, want 2", got) + } + + got = MostFrequent([]int{}) + if got != 0 { + t.Errorf("MostFrequent empty = %d, want 0", got) + } +} + +func TestIntersection(t *testing.T) { + got := Intersection([]int{1, 2, 3, 4}, []int{3, 4, 5, 6}) + want := []int{3, 4} + if !reflect.DeepEqual(got, want) { + t.Errorf("Intersection = %v, want %v", got, want) + } + + got = Intersection([]int{1, 2}, []int{3, 4}) + if len(got) != 0 { + t.Errorf("expected empty intersection, got %v", got) + } +} + +func TestMergeSortedSlices(t *testing.T) { + got := MergeSortedSlices([]int{1, 3, 5}, []int{2, 4, 6}) + want := []int{1, 2, 3, 4, 5, 6} + if !reflect.DeepEqual(got, want) { + t.Errorf("MergeSortedSlices = %v, want %v", got, want) + } +} + +func TestLongestCommonPrefix(t *testing.T) { + got := LongestCommonPrefix([]string{"flower", "flow", "flight"}) + if got != "fl" { + t.Errorf("LongestCommonPrefix = %q, want \"fl\"", got) + } + + got = LongestCommonPrefix([]string{"dog", "racecar", "car"}) + if got != "" { + t.Errorf("LongestCommonPrefix = %q, want \"\"", got) + } + + got = LongestCommonPrefix([]string{}) + if got != "" { + t.Errorf("LongestCommonPrefix empty = %q, want \"\"", got) + } +} + +func TestMaxSubarraySum(t *testing.T) { + got := MaxSubarraySum([]int{-2, 1, -3, 4, -1, 2, 1, -5, 4}) + if got != 6 { + t.Errorf("MaxSubarraySum = %d, want 6", got) + } + + got = MaxSubarraySum([]int{-1, -2, -3}) + if got != -1 { + t.Errorf("MaxSubarraySum all negative = %d, want -1", got) + } + + got = MaxSubarraySum([]int{}) + if got != 0 { + t.Errorf("MaxSubarraySum empty = %d, want 0", got) + } +} + +func TestIsPrime(t *testing.T) { + primes := []int{2, 3, 5, 7, 11, 13, 17, 19, 23} + for _, p := range primes { + if !IsPrime(p) { + t.Errorf("IsPrime(%d) = false, want true", p) + } + } + + nonPrimes := []int{0, 1, 4, 6, 8, 9, 10, 15} + for _, n := range nonPrimes { + if IsPrime(n) { + t.Errorf("IsPrime(%d) = true, want false", n) + } + } +} + +func TestPrimesUpTo(t *testing.T) { + got := PrimesUpTo(20) + want := []int{2, 3, 5, 7, 11, 13, 17, 19} + if !reflect.DeepEqual(got, want) { + t.Errorf("PrimesUpTo(20) = %v, want %v", got, want) + } +} + +func TestGCD(t *testing.T) { + tests := []struct { + a, b, want int + }{ + {12, 8, 4}, + {7, 13, 1}, + {0, 5, 5}, + {-12, 8, 4}, + } + + for _, tc := range tests { + got := GCD(tc.a, tc.b) + if got != tc.want { + t.Errorf("GCD(%d, %d) = %d, want %d", tc.a, tc.b, got, tc.want) + } + } +} + +func TestLCM(t *testing.T) { + tests := []struct { + a, b, want int + }{ + {4, 6, 12}, + {7, 13, 91}, + {0, 5, 0}, + } + + for _, tc := range tests { + got := LCM(tc.a, tc.b) + if got != tc.want { + t.Errorf("LCM(%d, %d) = %d, want %d", tc.a, tc.b, got, tc.want) + } + } +} diff --git a/code_to_optimize/go/calculator.go b/code_to_optimize/go/calculator.go new file mode 100644 index 000000000..161537293 --- /dev/null +++ b/code_to_optimize/go/calculator.go @@ -0,0 +1,117 @@ +package sample + +import "math" + +func Factorial(n int) int64 { + if n < 0 { + panic("factorial not defined for negative numbers") + } + if n <= 1 { + return 1 + } + return int64(n) * Factorial(n-1) +} + +func Power(base float64, exp int) float64 { + if exp < 0 { + return 1.0 / Power(base, -exp) + } + if exp == 0 { + return 1 + } + result := 1.0 + for i := 0; i < exp; i++ { + result *= base + } + return result +} + +func SumRange(start, end int) int64 { + var sum int64 + for i := start; i <= end; i++ { + sum += int64(i) + } + return sum +} + +func Average(nums []float64) float64 { + if len(nums) == 0 { + return 0 + } + sum := 0.0 + for _, n := range nums { + sum = sum + n + } + return sum / float64(len(nums)) +} + +func StandardDeviation(nums []float64) float64 { + if len(nums) == 0 { + return 0 + } + avg := Average(nums) + sumSqDiff := 0.0 + for _, n := range nums { + diff := n - avg + sumSqDiff = sumSqDiff + diff*diff + } + return math.Sqrt(sumSqDiff / float64(len(nums))) +} + +func Median(nums []float64) float64 { + if len(nums) == 0 { + return 0 + } + + sorted := make([]float64, len(nums)) + copy(sorted, nums) + for i := 0; i < len(sorted); i++ { + for j := i + 1; j < len(sorted); j++ { + if sorted[j] < sorted[i] { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + + mid := len(sorted) / 2 + if len(sorted)%2 == 0 { + return (sorted[mid-1] + sorted[mid]) / 2 + } + return sorted[mid] +} + +func NthRoot(x float64, n int) float64 { + if n <= 0 { + return 0 + } + if x < 0 && n%2 == 0 { + return 0 + } + + guess := x / float64(n) + for i := 0; i < 1000; i++ { + powered := Power(guess, n-1) + if powered == 0 { + break + } + guess = guess - (Power(guess, n)-x)/(float64(n)*powered) + } + return guess +} + +func Combinations(n, k int) int64 { + if k < 0 || k > n { + return 0 + } + if k == 0 || k == n { + return 1 + } + return Factorial(n) / (Factorial(k) * Factorial(n-k)) +} + +func Permutations(n, k int) int64 { + if k < 0 || k > n { + return 0 + } + return Factorial(n) / Factorial(n-k) +} diff --git a/code_to_optimize/go/calculator_test.go b/code_to_optimize/go/calculator_test.go new file mode 100644 index 000000000..0331695c1 --- /dev/null +++ b/code_to_optimize/go/calculator_test.go @@ -0,0 +1,149 @@ +package sample + +import ( + "math" + "testing" +) + +func TestFactorial(t *testing.T) { + tests := []struct { + n int + want int64 + }{ + {0, 1}, + {1, 1}, + {5, 120}, + {10, 3628800}, + } + + for _, tc := range tests { + got := Factorial(tc.n) + if got != tc.want { + t.Errorf("Factorial(%d) = %d, want %d", tc.n, got, tc.want) + } + } +} + +func TestFactorialPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for negative input") + } + }() + Factorial(-1) +} + +func TestPower(t *testing.T) { + tests := []struct { + base float64 + exp int + want float64 + }{ + {2, 10, 1024}, + {3, 0, 1}, + {5, 1, 5}, + {2, -1, 0.5}, + } + + for _, tc := range tests { + got := Power(tc.base, tc.exp) + if math.Abs(got-tc.want) > 1e-9 { + t.Errorf("Power(%f, %d) = %f, want %f", tc.base, tc.exp, got, tc.want) + } + } +} + +func TestSumRange(t *testing.T) { + if got := SumRange(1, 100); got != 5050 { + t.Errorf("SumRange(1,100) = %d, want 5050", got) + } + if got := SumRange(5, 5); got != 5 { + t.Errorf("SumRange(5,5) = %d, want 5", got) + } +} + +func TestAverage(t *testing.T) { + got := Average([]float64{1, 2, 3, 4, 5}) + if got != 3.0 { + t.Errorf("Average = %f, want 3.0", got) + } + + got = Average([]float64{}) + if got != 0 { + t.Errorf("Average empty = %f, want 0", got) + } +} + +func TestStandardDeviation(t *testing.T) { + got := StandardDeviation([]float64{2, 4, 4, 4, 5, 5, 7, 9}) + if math.Abs(got-2.0) > 0.01 { + t.Errorf("StandardDeviation = %f, want ~2.0", got) + } +} + +func TestMedian(t *testing.T) { + got := Median([]float64{3, 1, 2}) + if got != 2.0 { + t.Errorf("Median odd = %f, want 2.0", got) + } + + got = Median([]float64{4, 1, 3, 2}) + if got != 2.5 { + t.Errorf("Median even = %f, want 2.5", got) + } + + got = Median([]float64{}) + if got != 0 { + t.Errorf("Median empty = %f, want 0", got) + } +} + +func TestNthRoot(t *testing.T) { + got := NthRoot(27, 3) + if math.Abs(got-3.0) > 1e-6 { + t.Errorf("NthRoot(27,3) = %f, want 3.0", got) + } + + got = NthRoot(16, 4) + if math.Abs(got-2.0) > 1e-6 { + t.Errorf("NthRoot(16,4) = %f, want 2.0", got) + } +} + +func TestCombinations(t *testing.T) { + tests := []struct { + n, k int + want int64 + }{ + {5, 2, 10}, + {10, 3, 120}, + {5, 0, 1}, + {5, 5, 1}, + {3, 5, 0}, + } + + for _, tc := range tests { + got := Combinations(tc.n, tc.k) + if got != tc.want { + t.Errorf("Combinations(%d,%d) = %d, want %d", tc.n, tc.k, got, tc.want) + } + } +} + +func TestPermutations(t *testing.T) { + tests := []struct { + n, k int + want int64 + }{ + {5, 2, 20}, + {5, 0, 1}, + {3, 5, 0}, + } + + for _, tc := range tests { + got := Permutations(tc.n, tc.k) + if got != tc.want { + t.Errorf("Permutations(%d,%d) = %d, want %d", tc.n, tc.k, got, tc.want) + } + } +} diff --git a/code_to_optimize/go/fibonacci.go b/code_to_optimize/go/fibonacci.go new file mode 100644 index 000000000..fa9bb9ad1 --- /dev/null +++ b/code_to_optimize/go/fibonacci.go @@ -0,0 +1,108 @@ +package sample + +import "math" + +func Fibonacci(n int) int64 { + if n < 0 { + panic("fibonacci not defined for negative numbers") + } + if n <= 1 { + return int64(n) + } + return Fibonacci(n-1) + Fibonacci(n-2) +} + +func IsFibonacci(num int64) bool { + if num < 0 { + return false + } + check1 := 5*num*num + 4 + check2 := 5*num*num - 4 + return isPerfectSquare(check1) || isPerfectSquare(check2) +} + +func isPerfectSquare(n int64) bool { + if n < 0 { + return false + } + sqrt := int64(math.Sqrt(float64(n))) + return sqrt*sqrt == n +} + +func FibonacciSequence(n int) []int64 { + if n < 0 { + panic("n must be non-negative") + } + if n == 0 { + return []int64{} + } + + result := make([]int64, n) + for i := 0; i < n; i++ { + result[i] = Fibonacci(i) + } + return result +} + +func FibonacciIndex(fibNum int64) int { + if fibNum < 0 { + return -1 + } + if fibNum == 0 { + return 0 + } + if fibNum == 1 { + return 1 + } + + for index := 2; index <= 50; index++ { + fib := Fibonacci(index) + if fib == fibNum { + return index + } + if fib > fibNum { + return -1 + } + } + return -1 +} + +func SumFibonacci(n int) int64 { + if n <= 0 { + return 0 + } + var sum int64 + for i := 0; i < n; i++ { + sum += Fibonacci(i) + } + return sum +} + +func FibonacciUpTo(limit int64) []int64 { + var result []int64 + if limit <= 0 { + return result + } + + for index := 0; index <= 50; index++ { + fib := Fibonacci(index) + if fib >= limit { + break + } + result = append(result, fib) + } + return result +} + +func AreConsecutiveFibonacci(a, b int64) bool { + if !IsFibonacci(a) || !IsFibonacci(b) { + return false + } + indexA := FibonacciIndex(a) + indexB := FibonacciIndex(b) + diff := indexA - indexB + if diff < 0 { + diff = -diff + } + return diff == 1 +} diff --git a/code_to_optimize/go/fibonacci_test.go b/code_to_optimize/go/fibonacci_test.go new file mode 100644 index 000000000..3189ae754 --- /dev/null +++ b/code_to_optimize/go/fibonacci_test.go @@ -0,0 +1,138 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestFibonacci(t *testing.T) { + tests := []struct { + n int + want int64 + }{ + {0, 0}, + {1, 1}, + {2, 1}, + {5, 5}, + {10, 55}, + {20, 6765}, + } + + for _, tc := range tests { + got := Fibonacci(tc.n) + if got != tc.want { + t.Errorf("Fibonacci(%d) = %d, want %d", tc.n, got, tc.want) + } + } +} + +func TestFibonacciPanicsOnNegative(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for negative input") + } + }() + Fibonacci(-1) +} + +func TestIsFibonacci(t *testing.T) { + fibs := []int64{0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55} + for _, f := range fibs { + if !IsFibonacci(f) { + t.Errorf("IsFibonacci(%d) = false, want true", f) + } + } + + nonFibs := []int64{4, 6, 7, 9, 10, 22} + for _, f := range nonFibs { + if IsFibonacci(f) { + t.Errorf("IsFibonacci(%d) = true, want false", f) + } + } + + if IsFibonacci(-1) { + t.Error("IsFibonacci(-1) should be false") + } +} + +func TestFibonacciSequence(t *testing.T) { + got := FibonacciSequence(7) + want := []int64{0, 1, 1, 2, 3, 5, 8} + if !reflect.DeepEqual(got, want) { + t.Errorf("FibonacciSequence(7) = %v, want %v", got, want) + } + + got = FibonacciSequence(0) + if len(got) != 0 { + t.Errorf("FibonacciSequence(0) should be empty, got %v", got) + } +} + +func TestFibonacciIndex(t *testing.T) { + tests := []struct { + num int64 + want int + }{ + {0, 0}, + {1, 1}, + {5, 5}, + {8, 6}, + {55, 10}, + {4, -1}, + {-1, -1}, + } + + for _, tc := range tests { + got := FibonacciIndex(tc.num) + if got != tc.want { + t.Errorf("FibonacciIndex(%d) = %d, want %d", tc.num, got, tc.want) + } + } +} + +func TestSumFibonacci(t *testing.T) { + tests := []struct { + n int + want int64 + }{ + {0, 0}, + {1, 0}, + {5, 7}, + {7, 20}, + } + + for _, tc := range tests { + got := SumFibonacci(tc.n) + if got != tc.want { + t.Errorf("SumFibonacci(%d) = %d, want %d", tc.n, got, tc.want) + } + } +} + +func TestFibonacciUpTo(t *testing.T) { + got := FibonacciUpTo(10) + want := []int64{0, 1, 1, 2, 3, 5, 8} + if !reflect.DeepEqual(got, want) { + t.Errorf("FibonacciUpTo(10) = %v, want %v", got, want) + } + + got = FibonacciUpTo(0) + if len(got) != 0 { + t.Errorf("FibonacciUpTo(0) should be empty") + } +} + +func TestAreConsecutiveFibonacci(t *testing.T) { + if !AreConsecutiveFibonacci(5, 8) { + t.Error("5 and 8 are consecutive fibonacci numbers") + } + if !AreConsecutiveFibonacci(8, 5) { + t.Error("8 and 5 are consecutive fibonacci numbers") + } + if AreConsecutiveFibonacci(5, 13) { + t.Error("5 and 13 are not consecutive fibonacci numbers") + } + if AreConsecutiveFibonacci(4, 5) { + t.Error("4 is not a fibonacci number") + } +} diff --git a/code_to_optimize/go/go.mod b/code_to_optimize/go/go.mod new file mode 100644 index 000000000..d45a82bbd --- /dev/null +++ b/code_to_optimize/go/go.mod @@ -0,0 +1,3 @@ +module example/codeflash-go-sample + +go 1.26 diff --git a/code_to_optimize/go/graph.go b/code_to_optimize/go/graph.go new file mode 100644 index 000000000..d91da788c --- /dev/null +++ b/code_to_optimize/go/graph.go @@ -0,0 +1,197 @@ +package sample + +func BFS(graph map[int][]int, start int) []int { + visited := make(map[int]bool) + var result []int + queue := []int{start} + visited[start] = true + + for len(queue) > 0 { + node := queue[0] + queue = queue[1:] + result = append(result, node) + + neighbors := graph[node] + for i := 0; i < len(neighbors); i++ { + for j := i + 1; j < len(neighbors); j++ { + if neighbors[j] < neighbors[i] { + neighbors[i], neighbors[j] = neighbors[j], neighbors[i] + } + } + } + + for _, neighbor := range neighbors { + if !visited[neighbor] { + visited[neighbor] = true + queue = append(queue, neighbor) + } + } + } + return result +} + +func DFS(graph map[int][]int, start int) []int { + visited := make(map[int]bool) + var result []int + dfsHelper(graph, start, visited, &result) + return result +} + +func dfsHelper(graph map[int][]int, node int, visited map[int]bool, result *[]int) { + if visited[node] { + return + } + visited[node] = true + *result = append(*result, node) + + neighbors := make([]int, len(graph[node])) + copy(neighbors, graph[node]) + for i := 0; i < len(neighbors); i++ { + for j := i + 1; j < len(neighbors); j++ { + if neighbors[j] < neighbors[i] { + neighbors[i], neighbors[j] = neighbors[j], neighbors[i] + } + } + } + + for _, neighbor := range neighbors { + dfsHelper(graph, neighbor, visited, result) + } +} + +func ShortestPath(graph map[int][]int, start, end int) int { + if start == end { + return 0 + } + + visited := make(map[int]bool) + type entry struct { + node int + dist int + } + queue := []entry{{start, 0}} + visited[start] = true + + for len(queue) > 0 { + curr := queue[0] + queue = queue[1:] + + for _, neighbor := range graph[curr.node] { + if neighbor == end { + return curr.dist + 1 + } + if !visited[neighbor] { + visited[neighbor] = true + queue = append(queue, entry{neighbor, curr.dist + 1}) + } + } + } + return -1 +} + +func HasCycle(graph map[int][]int) bool { + visited := make(map[int]bool) + recStack := make(map[int]bool) + + for node := range graph { + if hasCycleDFS(graph, node, visited, recStack) { + return true + } + } + return false +} + +func hasCycleDFS(graph map[int][]int, node int, visited, recStack map[int]bool) bool { + if recStack[node] { + return true + } + if visited[node] { + return false + } + + visited[node] = true + recStack[node] = true + + for _, neighbor := range graph[node] { + if hasCycleDFS(graph, neighbor, visited, recStack) { + return true + } + } + + recStack[node] = false + return false +} + +func TopologicalSort(graph map[int][]int) []int { + inDegree := make(map[int]int) + for node := range graph { + if _, ok := inDegree[node]; !ok { + inDegree[node] = 0 + } + for _, neighbor := range graph[node] { + inDegree[neighbor]++ + } + } + + var queue []int + for node, degree := range inDegree { + if degree == 0 { + queue = append(queue, node) + } + } + + for i := 0; i < len(queue); i++ { + for j := i + 1; j < len(queue); j++ { + if queue[j] < queue[i] { + queue[i], queue[j] = queue[j], queue[i] + } + } + } + + var result []int + for len(queue) > 0 { + node := queue[0] + queue = queue[1:] + result = append(result, node) + + for _, neighbor := range graph[node] { + inDegree[neighbor]-- + if inDegree[neighbor] == 0 { + queue = append(queue, neighbor) + for i := 0; i < len(queue); i++ { + for j := i + 1; j < len(queue); j++ { + if queue[j] < queue[i] { + queue[i], queue[j] = queue[j], queue[i] + } + } + } + } + } + } + return result +} + +func ConnectedComponents(graph map[int][]int) [][]int { + visited := make(map[int]bool) + var components [][]int + + for node := range graph { + if !visited[node] { + var component []int + componentDFS(graph, node, visited, &component) + components = append(components, component) + } + } + return components +} + +func componentDFS(graph map[int][]int, node int, visited map[int]bool, component *[]int) { + if visited[node] { + return + } + visited[node] = true + *component = append(*component, node) + for _, neighbor := range graph[node] { + componentDFS(graph, neighbor, visited, component) + } +} diff --git a/code_to_optimize/go/graph_test.go b/code_to_optimize/go/graph_test.go new file mode 100644 index 000000000..d33a38fe0 --- /dev/null +++ b/code_to_optimize/go/graph_test.go @@ -0,0 +1,109 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestBFS(t *testing.T) { + graph := map[int][]int{ + 0: {1, 2}, + 1: {3}, + 2: {3}, + 3: {}, + } + got := BFS(graph, 0) + want := []int{0, 1, 2, 3} + if !reflect.DeepEqual(got, want) { + t.Errorf("BFS = %v, want %v", got, want) + } +} + +func TestBFSSingleNode(t *testing.T) { + graph := map[int][]int{0: {}} + got := BFS(graph, 0) + want := []int{0} + if !reflect.DeepEqual(got, want) { + t.Errorf("BFS single = %v, want %v", got, want) + } +} + +func TestDFS(t *testing.T) { + graph := map[int][]int{ + 0: {1, 2}, + 1: {3}, + 2: {3}, + 3: {}, + } + got := DFS(graph, 0) + want := []int{0, 1, 3, 2} + if !reflect.DeepEqual(got, want) { + t.Errorf("DFS = %v, want %v", got, want) + } +} + +func TestShortestPath(t *testing.T) { + graph := map[int][]int{ + 0: {1, 2}, + 1: {3}, + 2: {3}, + 3: {}, + } + + if got := ShortestPath(graph, 0, 3); got != 2 { + t.Errorf("ShortestPath(0,3) = %d, want 2", got) + } + if got := ShortestPath(graph, 0, 0); got != 0 { + t.Errorf("ShortestPath(0,0) = %d, want 0", got) + } + if got := ShortestPath(graph, 3, 0); got != -1 { + t.Errorf("ShortestPath(3,0) = %d, want -1", got) + } +} + +func TestHasCycle(t *testing.T) { + acyclic := map[int][]int{ + 0: {1}, + 1: {2}, + 2: {}, + } + if HasCycle(acyclic) { + t.Error("expected no cycle in DAG") + } + + cyclic := map[int][]int{ + 0: {1}, + 1: {2}, + 2: {0}, + } + if !HasCycle(cyclic) { + t.Error("expected cycle") + } +} + +func TestTopologicalSort(t *testing.T) { + graph := map[int][]int{ + 0: {1, 2}, + 1: {3}, + 2: {3}, + 3: {}, + } + got := TopologicalSort(graph) + want := []int{0, 1, 2, 3} + if !reflect.DeepEqual(got, want) { + t.Errorf("TopologicalSort = %v, want %v", got, want) + } +} + +func TestConnectedComponents(t *testing.T) { + graph := map[int][]int{ + 0: {1}, + 1: {0}, + 2: {3}, + 3: {2}, + } + components := ConnectedComponents(graph) + if len(components) != 2 { + t.Errorf("expected 2 components, got %d", len(components)) + } +} diff --git a/code_to_optimize/go/matrix.go b/code_to_optimize/go/matrix.go new file mode 100644 index 000000000..0dd7e1179 --- /dev/null +++ b/code_to_optimize/go/matrix.go @@ -0,0 +1,122 @@ +package sample + +import "math" + +func MatrixMultiply(a, b [][]float64) [][]float64 { + if len(a) == 0 || len(b) == 0 { + return nil + } + + rows := len(a) + cols := len(b[0]) + inner := len(b) + + result := make([][]float64, rows) + for i := range result { + result[i] = make([]float64, cols) + } + + for i := 0; i < rows; i++ { + for j := 0; j < cols; j++ { + sum := 0.0 + for k := 0; k < inner; k++ { + sum = sum + a[i][k]*b[k][j] + } + result[i][j] = sum + } + } + return result +} + +func MatrixTranspose(m [][]float64) [][]float64 { + if len(m) == 0 { + return nil + } + + rows := len(m) + cols := len(m[0]) + + result := make([][]float64, cols) + for i := range result { + result[i] = make([]float64, rows) + } + + for i := 0; i < rows; i++ { + for j := 0; j < cols; j++ { + result[j][i] = m[i][j] + } + } + return result +} + +func MatrixAdd(a, b [][]float64) [][]float64 { + if len(a) == 0 || len(b) == 0 { + return nil + } + + rows := len(a) + cols := len(a[0]) + + result := make([][]float64, rows) + for i := range result { + result[i] = make([]float64, cols) + for j := 0; j < cols; j++ { + result[i][j] = a[i][j] + b[i][j] + } + } + return result +} + +func MatrixScale(m [][]float64, scalar float64) [][]float64 { + if len(m) == 0 { + return nil + } + + rows := len(m) + cols := len(m[0]) + + result := make([][]float64, rows) + for i := range result { + result[i] = make([]float64, cols) + for j := 0; j < cols; j++ { + result[i][j] = m[i][j] * scalar + } + } + return result +} + +func DotProduct(a, b []float64) float64 { + sum := 0.0 + for i := 0; i < len(a); i++ { + sum = sum + a[i]*b[i] + } + return sum +} + +func VectorNorm(v []float64) float64 { + sum := 0.0 + for _, val := range v { + sum = sum + val*val + } + return math.Sqrt(sum) +} + +func CosineSimilarity(a, b []float64) float64 { + dot := DotProduct(a, b) + normA := VectorNorm(a) + normB := VectorNorm(b) + if normA == 0 || normB == 0 { + return 0 + } + return dot / (normA * normB) +} + +func FlattenMatrix(m [][]float64) []float64 { + var result []float64 + for _, row := range m { + for _, val := range row { + result = append(result, val) + } + } + return result +} diff --git a/code_to_optimize/go/matrix_test.go b/code_to_optimize/go/matrix_test.go new file mode 100644 index 000000000..90471c5fa --- /dev/null +++ b/code_to_optimize/go/matrix_test.go @@ -0,0 +1,112 @@ +package sample + +import ( + "math" + "reflect" + "testing" +) + +func TestMatrixMultiply(t *testing.T) { + a := [][]float64{{1, 2}, {3, 4}} + b := [][]float64{{5, 6}, {7, 8}} + got := MatrixMultiply(a, b) + want := [][]float64{{19, 22}, {43, 50}} + if !reflect.DeepEqual(got, want) { + t.Errorf("MatrixMultiply = %v, want %v", got, want) + } +} + +func TestMatrixMultiplyEmpty(t *testing.T) { + got := MatrixMultiply([][]float64{}, [][]float64{{1}}) + if got != nil { + t.Errorf("expected nil for empty input, got %v", got) + } +} + +func TestMatrixMultiplyIdentity(t *testing.T) { + a := [][]float64{{1, 2, 3}, {4, 5, 6}} + identity := [][]float64{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}} + got := MatrixMultiply(a, identity) + if !reflect.DeepEqual(got, a) { + t.Errorf("A * I = %v, want %v", got, a) + } +} + +func TestMatrixTranspose(t *testing.T) { + m := [][]float64{{1, 2, 3}, {4, 5, 6}} + got := MatrixTranspose(m) + want := [][]float64{{1, 4}, {2, 5}, {3, 6}} + if !reflect.DeepEqual(got, want) { + t.Errorf("MatrixTranspose = %v, want %v", got, want) + } +} + +func TestMatrixTransposeEmpty(t *testing.T) { + got := MatrixTranspose([][]float64{}) + if got != nil { + t.Errorf("expected nil for empty input") + } +} + +func TestMatrixAdd(t *testing.T) { + a := [][]float64{{1, 2}, {3, 4}} + b := [][]float64{{5, 6}, {7, 8}} + got := MatrixAdd(a, b) + want := [][]float64{{6, 8}, {10, 12}} + if !reflect.DeepEqual(got, want) { + t.Errorf("MatrixAdd = %v, want %v", got, want) + } +} + +func TestMatrixScale(t *testing.T) { + m := [][]float64{{1, 2}, {3, 4}} + got := MatrixScale(m, 2.0) + want := [][]float64{{2, 4}, {6, 8}} + if !reflect.DeepEqual(got, want) { + t.Errorf("MatrixScale = %v, want %v", got, want) + } +} + +func TestDotProduct(t *testing.T) { + got := DotProduct([]float64{1, 2, 3}, []float64{4, 5, 6}) + want := 32.0 + if got != want { + t.Errorf("DotProduct = %f, want %f", got, want) + } +} + +func TestVectorNorm(t *testing.T) { + got := VectorNorm([]float64{3, 4}) + want := 5.0 + if got != want { + t.Errorf("VectorNorm = %f, want %f", got, want) + } +} + +func TestCosineSimilarity(t *testing.T) { + a := []float64{1, 0} + b := []float64{0, 1} + got := CosineSimilarity(a, b) + if math.Abs(got) > 1e-9 { + t.Errorf("orthogonal vectors should have cosine similarity 0, got %f", got) + } + + got = CosineSimilarity([]float64{1, 2, 3}, []float64{1, 2, 3}) + if math.Abs(got-1.0) > 1e-9 { + t.Errorf("identical vectors should have cosine similarity 1, got %f", got) + } + + got = CosineSimilarity([]float64{0, 0}, []float64{1, 2}) + if got != 0 { + t.Errorf("zero vector should give 0, got %f", got) + } +} + +func TestFlattenMatrix(t *testing.T) { + m := [][]float64{{1, 2}, {3, 4}, {5, 6}} + got := FlattenMatrix(m) + want := []float64{1, 2, 3, 4, 5, 6} + if !reflect.DeepEqual(got, want) { + t.Errorf("FlattenMatrix = %v, want %v", got, want) + } +} diff --git a/code_to_optimize/go/sorter/sorting.go b/code_to_optimize/go/sorter/sorting.go new file mode 100644 index 000000000..22e821561 --- /dev/null +++ b/code_to_optimize/go/sorter/sorting.go @@ -0,0 +1,103 @@ +package sorter + +func BubbleSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + // Standard optimized bubble sort: + // - reduce inner loop bound each pass (last elements are already sorted) + // - stop early if no swaps occurred in a pass + for i := 0; i < n-1; i++ { + swapped := false + // after i passes, the last i elements are in place + limit := n - 1 - i + for j := 0; j < limit; j++ { + if result[j] > result[j+1] { + // swap + result[j], result[j+1] = result[j+1], result[j] + swapped = true + } + } + if !swapped { + break + } + } + return result +} + +func BubbleSortDescending(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if result[j] < result[j+1] { + temp := result[j] + result[j] = result[j+1] + result[j+1] = temp + } + } + } + return result +} + +func InsertionSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 1; i < n; i++ { + key := result[i] + j := i - 1 + for j >= 0 && result[j] > key { + result[j+1] = result[j] + j-- + } + result[j+1] = key + } + return result +} + +func SelectionSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n-1; i++ { + minIdx := i + for j := i + 1; j < n; j++ { + if result[j] < result[minIdx] { + minIdx = j + } + } + result[minIdx], result[i] = result[i], result[minIdx] + } + return result +} + +func IsSorted(arr []int) bool { + for i := 0; i < len(arr)-1; i++ { + if arr[i] > arr[i+1] { + return false + } + } + return true +} diff --git a/code_to_optimize/go/sorting.go b/code_to_optimize/go/sorting.go new file mode 100644 index 000000000..7de2a322e --- /dev/null +++ b/code_to_optimize/go/sorting.go @@ -0,0 +1,94 @@ +package sample + +func BubbleSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n; i++ { + for j := 0; j < n-1; j++ { + if result[j] > result[j+1] { + temp := result[j] + result[j] = result[j+1] + result[j+1] = temp + } + } + } + return result +} + +func BubbleSortDescending(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if result[j] < result[j+1] { + temp := result[j] + result[j] = result[j+1] + result[j+1] = temp + } + } + } + return result +} + +func InsertionSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 1; i < n; i++ { + key := result[i] + j := i - 1 + for j >= 0 && result[j] > key { + result[j+1] = result[j] + j-- + } + result[j+1] = key + } + return result +} + +func SelectionSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n-1; i++ { + minIdx := i + for j := i + 1; j < n; j++ { + if result[j] < result[minIdx] { + minIdx = j + } + } + result[minIdx], result[i] = result[i], result[minIdx] + } + return result +} + +func IsSorted(arr []int) bool { + for i := 0; i < len(arr)-1; i++ { + if arr[i] > arr[i+1] { + return false + } + } + return true +} diff --git a/code_to_optimize/go/sorting_test.go b/code_to_optimize/go/sorting_test.go new file mode 100644 index 000000000..ac6890693 --- /dev/null +++ b/code_to_optimize/go/sorting_test.go @@ -0,0 +1,122 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestBubbleSort(t *testing.T) { + tests := []struct { + input []int + expected []int + }{ + {[]int{5, 3, 1, 4, 2}, []int{1, 2, 3, 4, 5}}, + {[]int{3, 2, 1}, []int{1, 2, 3}}, + {[]int{1}, []int{1}}, + {[]int{}, []int{}}, + {[]int{1, 2, 3, 4, 5}, []int{1, 2, 3, 4, 5}}, + } + + for _, tc := range tests { + result := BubbleSort(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("BubbleSort(%v) = %v, want %v", tc.input, result, tc.expected) + } + } +} + +func TestBubbleSortWithDuplicates(t *testing.T) { + result := BubbleSort([]int{3, 2, 4, 1, 3, 2}) + expected := []int{1, 2, 2, 3, 3, 4} + if !reflect.DeepEqual(result, expected) { + t.Errorf("got %v, want %v", result, expected) + } +} + +func TestBubbleSortWithNegatives(t *testing.T) { + result := BubbleSort([]int{3, -2, 7, 0, -5}) + expected := []int{-5, -2, 0, 3, 7} + if !reflect.DeepEqual(result, expected) { + t.Errorf("got %v, want %v", result, expected) + } +} + +func TestBubbleSortDescending(t *testing.T) { + tests := []struct { + input []int + expected []int + }{ + {[]int{1, 3, 5, 2, 4}, []int{5, 4, 3, 2, 1}}, + {[]int{1, 2, 3}, []int{3, 2, 1}}, + {[]int{}, []int{}}, + } + + for _, tc := range tests { + result := BubbleSortDescending(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("BubbleSortDescending(%v) = %v, want %v", tc.input, result, tc.expected) + } + } +} + +func TestInsertionSort(t *testing.T) { + tests := []struct { + input []int + expected []int + }{ + {[]int{5, 3, 1, 4, 2}, []int{1, 2, 3, 4, 5}}, + {[]int{3, 2, 1}, []int{1, 2, 3}}, + {[]int{1}, []int{1}}, + {[]int{}, []int{}}, + } + + for _, tc := range tests { + result := InsertionSort(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("InsertionSort(%v) = %v, want %v", tc.input, result, tc.expected) + } + } +} + +func TestSelectionSort(t *testing.T) { + tests := []struct { + input []int + expected []int + }{ + {[]int{5, 3, 1, 4, 2}, []int{1, 2, 3, 4, 5}}, + {[]int{3, 2, 1}, []int{1, 2, 3}}, + {[]int{1}, []int{1}}, + } + + for _, tc := range tests { + result := SelectionSort(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("SelectionSort(%v) = %v, want %v", tc.input, result, tc.expected) + } + } +} + +func TestIsSorted(t *testing.T) { + if !IsSorted([]int{1, 2, 3, 4, 5}) { + t.Error("expected sorted") + } + if !IsSorted([]int{1}) { + t.Error("expected sorted") + } + if !IsSorted([]int{}) { + t.Error("expected sorted") + } + if IsSorted([]int{5, 3, 1}) { + t.Error("expected not sorted") + } +} + +func TestBubbleSortDoesNotMutateInput(t *testing.T) { + original := []int{5, 3, 1, 4, 2} + saved := make([]int, len(original)) + copy(saved, original) + BubbleSort(original) + if !reflect.DeepEqual(original, saved) { + t.Errorf("input was mutated: got %v, want %v", original, saved) + } +} diff --git a/code_to_optimize/go/stringutils.go b/code_to_optimize/go/stringutils.go new file mode 100644 index 000000000..5ee8166de --- /dev/null +++ b/code_to_optimize/go/stringutils.go @@ -0,0 +1,125 @@ +package sample + +import "strings" + +func ReverseString(s string) string { + result := "" + for i := len(s) - 1; i >= 0; i-- { + result = result + string(s[i]) + } + return result +} + +func IsPalindrome(s string) bool { + reversed := ReverseString(s) + return s == reversed +} + +func CountWords(s string) int { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + return 0 + } + return len(strings.Fields(trimmed)) +} + +func CapitalizeWords(s string) string { + if s == "" { + return s + } + + words := strings.Split(s, " ") + result := "" + + for i, word := range words { + if len(word) > 0 { + capitalized := strings.ToUpper(word[:1]) + strings.ToLower(word[1:]) + result = result + capitalized + } + if i < len(words)-1 { + result = result + " " + } + } + return result +} + +func CountOccurrences(s, sub string) int { + if sub == "" { + return 0 + } + + count := 0 + index := 0 + for { + pos := strings.Index(s[index:], sub) + if pos == -1 { + break + } + count++ + index = index + pos + 1 + } + return count +} + +func RemoveWhitespace(s string) string { + result := "" + for _, c := range s { + if c != ' ' && c != '\t' && c != '\n' && c != '\r' { + result = result + string(c) + } + } + return result +} + +func FindAllIndices(s string, c byte) []int { + var indices []int + for i := 0; i < len(s); i++ { + if s[i] == c { + indices = append(indices, i) + } + } + return indices +} + +func IsNumeric(s string) bool { + if s == "" { + return false + } + for _, c := range s { + if c < '0' || c > '9' { + return false + } + } + return true +} + +func Repeat(s string, n int) string { + if n <= 0 { + return "" + } + result := "" + for i := 0; i < n; i++ { + result = result + s + } + return result +} + +func Truncate(s string, maxLen int) string { + if maxLen <= 0 { + return "" + } + if len(s) <= maxLen { + return s + } + if maxLen <= 3 { + return s[:maxLen] + } + return s[:maxLen-3] + "..." +} + +func ToTitleCase(s string) string { + if s == "" { + return s + } + return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) +} diff --git a/code_to_optimize/go/stringutils_test.go b/code_to_optimize/go/stringutils_test.go new file mode 100644 index 000000000..025928c2c --- /dev/null +++ b/code_to_optimize/go/stringutils_test.go @@ -0,0 +1,216 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestReverseString(t *testing.T) { + tests := []struct { + input, want string + }{ + {"hello", "olleh"}, + {"a", "a"}, + {"", ""}, + {"abcd", "dcba"}, + } + + for _, tc := range tests { + got := ReverseString(tc.input) + if got != tc.want { + t.Errorf("ReverseString(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestIsPalindrome(t *testing.T) { + palindromes := []string{"racecar", "madam", "a", "", "abba"} + for _, s := range palindromes { + if !IsPalindrome(s) { + t.Errorf("IsPalindrome(%q) = false, want true", s) + } + } + + nonPalindromes := []string{"hello", "ab"} + for _, s := range nonPalindromes { + if IsPalindrome(s) { + t.Errorf("IsPalindrome(%q) = true, want false", s) + } + } +} + +func TestCountWords(t *testing.T) { + tests := []struct { + input string + want int + }{ + {"hello world test", 3}, + {"hello", 1}, + {"", 0}, + {" ", 0}, + {" multiple spaces between words ", 4}, + } + + for _, tc := range tests { + got := CountWords(tc.input) + if got != tc.want { + t.Errorf("CountWords(%q) = %d, want %d", tc.input, got, tc.want) + } + } +} + +func TestCapitalizeWords(t *testing.T) { + tests := []struct { + input, want string + }{ + {"hello world", "Hello World"}, + {"HELLO", "Hello"}, + {"", ""}, + {"one two three", "One Two Three"}, + } + + for _, tc := range tests { + got := CapitalizeWords(tc.input) + if got != tc.want { + t.Errorf("CapitalizeWords(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestCountOccurrences(t *testing.T) { + tests := []struct { + s, sub string + want int + }{ + {"hello hello", "hello", 2}, + {"aaa", "a", 3}, + {"aaa", "aa", 2}, + {"hello", "world", 0}, + {"hello", "", 0}, + } + + for _, tc := range tests { + got := CountOccurrences(tc.s, tc.sub) + if got != tc.want { + t.Errorf("CountOccurrences(%q, %q) = %d, want %d", tc.s, tc.sub, got, tc.want) + } + } +} + +func TestRemoveWhitespace(t *testing.T) { + tests := []struct { + input, want string + }{ + {"hello world", "helloworld"}, + {" a b c ", "abc"}, + {"test", "test"}, + {" ", ""}, + {"", ""}, + } + + for _, tc := range tests { + got := RemoveWhitespace(tc.input) + if got != tc.want { + t.Errorf("RemoveWhitespace(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestFindAllIndices(t *testing.T) { + got := FindAllIndices("hello", 'l') + want := []int{2, 3} + if !reflect.DeepEqual(got, want) { + t.Errorf("FindAllIndices(\"hello\", 'l') = %v, want %v", got, want) + } + + got = FindAllIndices("aaa", 'a') + if len(got) != 3 { + t.Errorf("expected 3 indices, got %d", len(got)) + } + + got = FindAllIndices("hello", 'z') + if len(got) != 0 { + t.Errorf("expected 0 indices, got %d", len(got)) + } + + got = FindAllIndices("", 'a') + if len(got) != 0 { + t.Errorf("expected 0 indices, got %d", len(got)) + } +} + +func TestIsNumeric(t *testing.T) { + numerics := []string{"12345", "0", "007"} + for _, s := range numerics { + if !IsNumeric(s) { + t.Errorf("IsNumeric(%q) = false, want true", s) + } + } + + nonNumerics := []string{"12.34", "-123", "abc", "12a34", ""} + for _, s := range nonNumerics { + if IsNumeric(s) { + t.Errorf("IsNumeric(%q) = true, want false", s) + } + } +} + +func TestRepeat(t *testing.T) { + tests := []struct { + s string + n int + want string + }{ + {"abc", 3, "abcabcabc"}, + {"a", 3, "aaa"}, + {"abc", 0, ""}, + {"abc", -1, ""}, + } + + for _, tc := range tests { + got := Repeat(tc.s, tc.n) + if got != tc.want { + t.Errorf("Repeat(%q, %d) = %q, want %q", tc.s, tc.n, got, tc.want) + } + } +} + +func TestTruncate(t *testing.T) { + tests := []struct { + s string + maxLen int + want string + }{ + {"hello", 10, "hello"}, + {"hello world", 6, "hel..."}, + {"hello world", 8, "hello..."}, + {"hello", 0, ""}, + {"hello", 3, "hel"}, + } + + for _, tc := range tests { + got := Truncate(tc.s, tc.maxLen) + if got != tc.want { + t.Errorf("Truncate(%q, %d) = %q, want %q", tc.s, tc.maxLen, got, tc.want) + } + } +} + +func TestToTitleCase(t *testing.T) { + tests := []struct { + input, want string + }{ + {"hello", "Hello"}, + {"HELLO", "Hello"}, + {"hELLO", "Hello"}, + {"a", "A"}, + {"", ""}, + } + + for _, tc := range tests { + got := ToTitleCase(tc.input) + if got != tc.want { + t.Errorf("ToTitleCase(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 2db13efe8..cd4d94787 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -116,13 +116,17 @@ def process_pyproject_config(args: Namespace) -> Namespace: # Default to module_root if not specified is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript") is_java_project = pyproject_config.get("language") == "java" + is_go_project = pyproject_config.get("language") == "go" # Set the test framework singleton for JS/TS projects if is_js_ts_project and pyproject_config.get("test_framework"): set_current_test_framework(pyproject_config["test_framework"]) if args.tests_root is None: - if is_java_project: + if is_go_project: + # this is just a placeholder, in go we put generated test files in the same package as the source + args.tests_root = args.module_root + elif is_java_project: # Try standard Maven/Gradle test directories for test_dir in ["src/test/java", "test", "tests"]: test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir) @@ -202,7 +206,10 @@ def process_pyproject_config(args: Namespace) -> Namespace: args.benchmarks_root = Path(args.benchmarks_root).resolve() args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path) - if is_java_project and pyproject_file_path.is_dir(): + if is_go_project and pyproject_file_path.is_dir(): + args.project_root = pyproject_file_path.resolve() + args.test_project_root = pyproject_file_path.resolve() + elif is_java_project and pyproject_file_path.is_dir(): # For Java projects, pyproject_file_path IS the project root directory (not a file). # Override project_root which may have resolved to a sub-module. args.project_root = pyproject_file_path.resolve() diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index d4da0ed04..bd44cb761 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -29,6 +29,7 @@ get_suggestions, should_modify_pyproject_toml, ) +from codeflash.cli_cmds.init_go import init_go_project from codeflash.cli_cmds.init_java import init_java_project from codeflash.cli_cmds.init_javascript import ProjectLanguage, detect_project_language, init_js_project from codeflash.code_utils.code_utils import validate_relative_directory_path @@ -61,6 +62,10 @@ def init_codeflash() -> None: # Detect project language project_language = detect_project_language() + if project_language == ProjectLanguage.GO: + init_go_project() + return + if project_language == ProjectLanguage.JAVA: init_java_project() return diff --git a/codeflash/cli_cmds/init_go.py b/codeflash/cli_cmds/init_go.py new file mode 100644 index 000000000..032072231 --- /dev/null +++ b/codeflash/cli_cmds/init_go.py @@ -0,0 +1,188 @@ +"""Go project initialization for Codeflash.""" + +from __future__ import annotations + +import os +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Union + +import click +import inquirer +from git import InvalidGitRepositoryError, Repo +from rich.console import Group +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from codeflash.cli_cmds.console import console +from codeflash.code_utils.compat import LF +from codeflash.code_utils.git_utils import get_git_remotes +from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell +from codeflash.languages.golang.config import detect_go_project, detect_go_version +from codeflash.telemetry.posthog_cf import ph + + +@dataclass(frozen=True) +class GoSetupInfo: + module_root_override: Union[str, None] = None + test_root_override: Union[str, None] = None + formatter_override: Union[list[str], None] = None + git_remote: str = "origin" + disable_telemetry: bool = False + ignore_paths: list[str] | None = None + + +def _get_theme() -> Any: + from codeflash.cli_cmds.init_config import CodeflashTheme + + return CodeflashTheme() + + +def init_go_project() -> None: + from codeflash.cli_cmds.github_workflow import install_github_actions + from codeflash.cli_cmds.init_auth import install_github_app, prompt_api_key + + lang_panel = Panel( + Text( + "Go project detected!\n\nI'll help you set up Codeflash for your project.", style="cyan", justify="center" + ), + title="Go Setup", + border_style="bright_cyan", + ) + console.print(lang_panel) + console.print() + + did_add_new_key = prompt_api_key() + + setup_info = collect_go_setup_info() + git_remote = setup_info.git_remote or "origin" + + install_github_app(git_remote) + + install_github_actions(override_formatter_check=True) + + usage_table = Table(show_header=False, show_lines=False, border_style="dim") + usage_table.add_column("Command", style="cyan") + usage_table.add_column("Description", style="white") + + usage_table.add_row("codeflash --file --function ", "Optimize a specific function") + usage_table.add_row("codeflash --all", "Optimize all functions in all files") + usage_table.add_row("codeflash --help", "See all available options") + + completion_message = "Codeflash is now set up for your Go project!\n\nYou can now run any of these commands:" + + if did_add_new_key: + completion_message += ( + "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!" + ) + if os.name == "nt": + reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}" + else: + reload_cmd = f"source {get_shell_rc_path()}" + completion_message += f"\nOr run: {reload_cmd}" + + completion_panel = Panel( + Group(Text(completion_message, style="bold green"), Text(""), usage_table), + title="Setup Complete!", + border_style="bright_green", + padding=(1, 2), + ) + console.print(completion_panel) + + ph("cli-go-installation-successful", {"did_add_new_key": did_add_new_key}) + sys.exit(0) + + +def collect_go_setup_info() -> GoSetupInfo: + + from codeflash.cli_cmds.init_config import ask_for_telemetry + + curdir = Path.cwd() + + if not os.access(curdir, os.W_OK): + click.echo(f"The current directory isn't writable, please check your folder permissions and try again.{LF}") + sys.exit(1) + + config = detect_go_project(curdir) + module_path = config.module_path if config else "unknown" + go_version = (config.go_version if config else None) or detect_go_version() or "unknown" + has_vendor = config.has_vendor if config else False + + detection_table = Table(show_header=False, box=None, padding=(0, 2)) + detection_table.add_column("Setting", style="cyan") + detection_table.add_column("Value", style="green") + detection_table.add_row("Module", module_path) + detection_table.add_row("Go version", go_version) + detection_table.add_row("Source root", ".") + detection_table.add_row("Test root", ". (co-located)") + detection_table.add_row("Formatter", "gofmt") + if has_vendor: + detection_table.add_row("Vendor", "yes (vendor/ detected)") + + detection_panel = Panel( + Group(Text("Auto-detected settings for your Go project:\n", style="cyan"), detection_table), + title="Auto-Detection Results", + border_style="bright_blue", + ) + console.print(detection_panel) + console.print() + + git_remote = _get_git_remote_for_setup() + + disable_telemetry = not ask_for_telemetry() + + return GoSetupInfo(git_remote=git_remote, disable_telemetry=disable_telemetry) + + +def _get_git_remote_for_setup() -> str: + try: + repo = Repo(Path.cwd(), search_parent_directories=True) + git_remotes = get_git_remotes(repo) + if not git_remotes: + return "" + + if len(git_remotes) == 1: + return git_remotes[0] + + git_panel = Panel( + Text( + "Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.", + style="blue", + ), + title="Git Remote Setup", + border_style="bright_blue", + ) + console.print(git_panel) + console.print() + + git_questions = [ + inquirer.List( + "git_remote", + message="Which git remote should Codeflash use?", + choices=git_remotes, + default="origin", + carousel=True, + ) + ] + + git_answers = inquirer.prompt(git_questions, theme=_get_theme()) + return git_answers["git_remote"] if git_answers else git_remotes[0] + except InvalidGitRepositoryError: + return "" + + +def get_go_runtime_setup_steps() -> str: + return """- name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: 'stable'""" + + +def get_go_dependency_installation_commands() -> str: + return "go mod download" + + +def get_go_test_command() -> str: + return "go test ./..." diff --git a/codeflash/cli_cmds/init_javascript.py b/codeflash/cli_cmds/init_javascript.py index fcd3c4b57..20f76d249 100644 --- a/codeflash/cli_cmds/init_javascript.py +++ b/codeflash/cli_cmds/init_javascript.py @@ -38,6 +38,7 @@ class ProjectLanguage(Enum): JAVASCRIPT = auto() TYPESCRIPT = auto() JAVA = auto() + GO = auto() class JsPackageManager(Enum): @@ -89,6 +90,10 @@ def detect_project_language(project_root: Path | None = None) -> ProjectLanguage """ root = project_root or Path.cwd() + # Go detection (go.mod is definitive) + if (root / "go.mod").exists(): + return ProjectLanguage.GO + # Java detection (pom.xml or build.gradle is definitive) has_pom = (root / "pom.xml").exists() has_gradle = (root / "build.gradle").exists() or (root / "build.gradle.kts").exists() diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 196779589..87960fa3f 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -12,6 +12,28 @@ ALL_CONFIG_FILES: dict[Path, dict[str, Path]] = {} +def _try_parse_go_config() -> tuple[dict[str, Any], Path] | None: + dir_path = Path.cwd() + while dir_path != dir_path.parent: + if (dir_path / "go.mod").exists(): + module_root = str(dir_path.resolve()) + return { + "language": "go", + "module_root": module_root, + "tests_root": module_root, + "pytest_cmd": "pytest", + "git_remote": "origin", + "disable_telemetry": False, + "disable_imports_sorting": False, + "override_fixtures": False, + "benchmark": False, + "formatter_cmds": [], + "ignore_paths": [], + }, dir_path + dir_path = dir_path.parent + return None + + def _try_parse_java_build_config() -> tuple[dict[str, Any], Path] | None: """Detect Java project from build files and parse config from pom.xml/gradle.properties. @@ -106,11 +128,23 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]: def parse_config_file( config_file_path: Path | None = None, override_formatter_check: bool = False ) -> tuple[dict[str, Any], Path]: - # Detect all config sources — Java build files, package.json, pyproject.toml + # Detect all config sources — Go modules, Java build files, package.json, pyproject.toml + go_result = _try_parse_go_config() if config_file_path is None else None java_result = _try_parse_java_build_config() if config_file_path is None else None package_json_path = find_package_json(config_file_path) pyproject_toml_path = find_closest_config_file("pyproject.toml") if config_file_path is None else None + # Use Go config only if no closer config exists + if go_result is not None: + go_depth = len(go_result[1].parts) + has_closer = ( + (java_result is not None and len(java_result[1].parts) >= go_depth) + or (package_json_path is not None and len(package_json_path.parent.parts) >= go_depth) + or (pyproject_toml_path is not None and len(pyproject_toml_path.parent.parts) >= go_depth) + ) + if not has_closer: + return go_result + # Use Java config only if no closer JS/Python config exists (monorepo support). # In a monorepo with a parent pom.xml and a child package.json, the closer config wins. if java_result is not None: diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py index b0daea0fb..0ec0f87fd 100644 --- a/codeflash/languages/__init__.py +++ b/codeflash/languages/__init__.py @@ -31,6 +31,7 @@ from codeflash.languages.current import ( current_language, current_language_support, + is_go, is_java, is_javascript, is_python, @@ -83,6 +84,10 @@ def __getattr__(name: str): from codeflash.languages.java.support import JavaSupport return JavaSupport + if name == "GoSupport": + from codeflash.languages.golang.support import GoSupport + + return GoSupport msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) @@ -106,6 +111,7 @@ def __getattr__(name: str): "get_language_support", "get_supported_extensions", "get_supported_languages", + "is_go", "is_java", "is_javascript", "is_jest", diff --git a/codeflash/languages/current.py b/codeflash/languages/current.py index b9e45d367..8be5fd07a 100644 --- a/codeflash/languages/current.py +++ b/codeflash/languages/current.py @@ -113,6 +113,16 @@ def is_java() -> bool: return _current_language == Language.JAVA +def is_go() -> bool: + """Check if the current language is Go. + + Returns: + True if the current language is Go. + + """ + return _current_language == Language.GO + + def current_language_support() -> LanguageSupport: """Get the LanguageSupport instance for the current language. diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index 71ad03b18..859e6ba16 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -21,7 +21,7 @@ from rich.tree import Tree import codeflash.code_utils._libcst_cache # noqa: F401 -from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient +from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import ( @@ -78,6 +78,7 @@ AdaptiveOptimizedCandidate, AIServiceAdaptiveOptimizeRequest, AIServiceCodeRepairRequest, + AIServiceRefinerRequest, BestOptimization, CandidateEvaluationContext, GeneratedTests, @@ -1712,7 +1713,9 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio logger.debug(f"Failed to instrument test file {test_file} for performance testing") continue - # For JS/TS, preserve .test.ts or .spec.ts suffix for Jest pattern matching + # Preserve language-specific test file naming conventions: + # JS/TS: .test.ts / .spec.ts for Jest pattern matching + # Go: _test.go required by `go test` def get_instrumented_path(original_path: str, suffix: str) -> Path: path_obj = Path(original_path) stem = path_obj.stem @@ -1724,6 +1727,9 @@ def get_instrumented_path(original_path: str, suffix: str) -> Path: elif ".spec" in stem: base, _ = stem.rsplit(".spec", 1) new_stem = f"{base}{suffix}.spec" + elif stem.endswith("_test") and ext == ".go": + base = stem.removesuffix("_test") + new_stem = f"{base}{suffix}_test" else: new_stem = f"{stem}{suffix}" diff --git a/codeflash/languages/golang/__init__.py b/codeflash/languages/golang/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/languages/golang/comparator.py b/codeflash/languages/golang/comparator.py new file mode 100644 index 000000000..0a3b23850 --- /dev/null +++ b/codeflash/languages/golang/comparator.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass +class TestDiff: + test_name: str + original_passed: bool + candidate_passed: bool + message: str + + +def compare_test_results( + original_results_path: Path, + candidate_results_path: Path, + project_root: Path | None = None, + project_classpath: str | None = None, +) -> tuple[bool, list[TestDiff]]: + original = _load_results(original_results_path) + candidate = _load_results(candidate_results_path) + + diffs: list[TestDiff] = [] + + all_tests = set(original.keys()) | set(candidate.keys()) + + for test_name in sorted(all_tests): + orig = original.get(test_name) + cand = candidate.get(test_name) + + if orig is None: + diffs.append( + TestDiff( + test_name=test_name, + original_passed=False, + candidate_passed=cand or False, + message=f"Test {test_name} only present in candidate results", + ) + ) + continue + + if cand is None: + diffs.append( + TestDiff( + test_name=test_name, + original_passed=orig, + candidate_passed=False, + message=f"Test {test_name} missing from candidate results", + ) + ) + continue + + if orig != cand: + diffs.append( + TestDiff( + test_name=test_name, + original_passed=orig, + candidate_passed=cand, + message=f"Test {test_name}: original {'passed' if orig else 'failed'}, candidate {'passed' if cand else 'failed'}", + ) + ) + + are_equivalent = len(diffs) == 0 + return are_equivalent, diffs + + +def _load_results(path: Path) -> dict[str, bool]: + results: dict[str, bool] = {} + try: + content = path.read_text(encoding="utf-8") + except Exception: + logger.debug("Could not read results file %s", path) + return results + + for line in content.splitlines(): + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + except json.JSONDecodeError: + continue + + action = event.get("Action") + test_name = event.get("Test") + if test_name is None: + continue + + if action == "pass": + results[test_name] = True + elif action == "fail": + results[test_name] = False + + return results diff --git a/codeflash/languages/golang/config.py b/codeflash/languages/golang/config.py new file mode 100644 index 000000000..25eae4cce --- /dev/null +++ b/codeflash/languages/golang/config.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import logging +import re +import subprocess +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class GoProjectConfig: + project_root: Path + module_path: str + go_version: str | None = None + has_vendor: bool = False + + +def detect_go_project(project_root: Path) -> GoProjectConfig | None: + go_mod = project_root / "go.mod" + if not go_mod.exists(): + return None + + module_path = "" + go_version = None + + try: + content = go_mod.read_text(encoding="utf-8") + for line in content.splitlines(): + line = line.strip() + if line.startswith("module "): + module_path = line[len("module ") :].strip() + elif line.startswith("go "): + go_version = line[len("go ") :].strip() + except (OSError, UnicodeDecodeError): + logger.warning("Failed to read go.mod at %s", go_mod) + return None + + has_vendor = (project_root / "vendor").is_dir() + + return GoProjectConfig( + project_root=project_root, module_path=module_path, go_version=go_version, has_vendor=has_vendor + ) + + +def detect_go_version() -> str | None: + try: + result = subprocess.run(["go", "version"], capture_output=True, text=True, timeout=10, check=False) + if result.returncode != 0: + return None + match = re.search(r"go(\d+\.\d+(?:\.\d+)?)", result.stdout) + if match: + return match.group(1) + except (FileNotFoundError, subprocess.TimeoutExpired, OSError): + pass + return None + + +def is_go_project(project_root: Path) -> bool: + if (project_root / "go.mod").exists(): + return True + return any(project_root.glob("*.go")) diff --git a/codeflash/languages/golang/context.py b/codeflash/languages/golang/context.py new file mode 100644 index 000000000..eec372e5d --- /dev/null +++ b/codeflash/languages/golang/context.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from codeflash.languages.base import CodeContext, HelperFunction +from codeflash.languages.golang.parser import GoAnalyzer +from codeflash.languages.language_enum import Language + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + +logger = logging.getLogger(__name__) + + +def extract_code_context( + function: FunctionToOptimize, + project_root: Path, + module_root: Path | None = None, + analyzer: GoAnalyzer | None = None, +) -> CodeContext: + analyzer = analyzer or GoAnalyzer() + + try: + source = function.file_path.read_text(encoding="utf-8") + except Exception: + logger.exception("Failed to read %s", function.file_path) + return CodeContext(target_code="", target_file=function.file_path, language=Language.GO) + + receiver_type = _get_receiver_type(function) + target_code = analyzer.extract_function_source(source, function.function_name, receiver_type=receiver_type) + if target_code is None: + target_code = "" + + imports = analyzer.find_imports(source) + import_lines = [_import_to_line(imp) for imp in imports] + + read_only_parts: list[str] = [] + if receiver_type: + struct_ctx = _extract_struct_context(source, receiver_type, analyzer) + if struct_ctx: + read_only_parts.append(struct_ctx) + + init_ctx = _extract_init_context(source, analyzer) + if init_ctx: + read_only_parts.append(init_ctx) + + helpers = find_helper_functions(source, function, analyzer) + + return CodeContext( + target_code=target_code, + target_file=function.file_path, + helper_functions=helpers, + read_only_context="\n\n".join(read_only_parts), + imports=import_lines, + language=Language.GO, + ) + + +def find_helper_functions( + source: str, function: FunctionToOptimize, analyzer: GoAnalyzer | None = None +) -> list[HelperFunction]: + analyzer = analyzer or GoAnalyzer() + target_name = function.function_name + receiver_type = _get_receiver_type(function) + + all_functions = analyzer.find_functions(source) + all_methods = analyzer.find_methods(source) + + candidate_names: set[str] = set() + for func in all_functions: + if func.name not in ("init", "main") and func.name != target_name: + candidate_names.add(func.name) + for method in all_methods: + if not (method.name == target_name and method.receiver_name == receiver_type): + candidate_names.add(method.name) + + referenced = analyzer.collect_body_identifiers(source, target_name, receiver_type=receiver_type) + needed = referenced & candidate_names + + seen: set[str] = set() + queue = list(needed) + while queue: + name = queue.pop() + if name in seen: + continue + seen.add(name) + ids = analyzer.collect_body_identifiers(source, name) + if not ids: + for method in all_methods: + if method.name == name: + ids = analyzer.collect_body_identifiers(source, name, receiver_type=method.receiver_name) + if ids: + break + for transitive in ids & candidate_names: + if transitive not in seen: + queue.append(transitive) + + helpers: list[HelperFunction] = [] + + for func in all_functions: + if func.name not in seen: + continue + extracted = analyzer.extract_function_source(source, func.name) + if extracted is None: + continue + helpers.append( + HelperFunction( + name=func.name, + qualified_name=func.name, + file_path=function.file_path, + source_code=extracted, + start_line=func.starting_line, + end_line=func.ending_line, + ) + ) + + for method in all_methods: + if method.name not in seen: + continue + extracted = analyzer.extract_function_source(source, method.name, receiver_type=method.receiver_name) + if extracted is None: + continue + qualified = f"{method.receiver_name}.{method.name}" + helpers.append( + HelperFunction( + name=method.name, + qualified_name=qualified, + file_path=function.file_path, + source_code=extracted, + start_line=method.starting_line, + end_line=method.ending_line, + ) + ) + + return helpers + + +def _get_receiver_type(function: FunctionToOptimize) -> str | None: + if function.parents: + return function.parents[0].name + return None + + +def _import_to_line(imp: object) -> str: + path = getattr(imp, "path", "") + alias = getattr(imp, "alias", None) + if alias: + return f'{alias} "{path}"' + return f'"{path}"' + + +def _extract_struct_context(source: str, struct_name: str, analyzer: GoAnalyzer) -> str: + structs = analyzer.find_structs(source) + for s in structs: + if s.name == struct_name: + lines = source.splitlines() + return "\n".join(lines[s.starting_line - 1 : s.ending_line]) + return "" + + +def _extract_init_context(source: str, analyzer: GoAnalyzer) -> str: + init_source = analyzer.extract_function_source(source, "init") + if init_source is None: + return "" + + init_ids = analyzer.collect_body_identifiers(source, "init") + if not init_ids: + return init_source + + parts: list[str] = [] + + for decl in analyzer.find_global_declarations(source): + if init_ids & set(decl.names): + parts.append(decl.source_code) + + for struct in analyzer.find_structs(source): + if struct.name in init_ids: + lines = source.splitlines() + parts.append("\n".join(lines[struct.starting_line - 1 : struct.ending_line])) + + parts.append(init_source) + return "\n\n".join(parts) diff --git a/codeflash/languages/golang/discovery.py b/codeflash/languages/golang/discovery.py new file mode 100644 index 000000000..c2e9ede2a --- /dev/null +++ b/codeflash/languages/golang/discovery.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from codeflash.languages.golang.parser import GoAnalyzer + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.languages.base import FunctionFilterCriteria + from codeflash.languages.golang.parser import GoFunctionNode, GoMethodNode + from codeflash.models.function_types import FunctionToOptimize + + +logger = logging.getLogger(__name__) + +_SKIP_FUNCTION_NAMES = frozenset({"init", "main"}) + + +def discover_functions( + file_path: Path, filter_criteria: FunctionFilterCriteria | None = None, analyzer: GoAnalyzer | None = None +) -> list[FunctionToOptimize]: + try: + source = file_path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError): + logger.warning("Failed to read Go file: %s", file_path) + return [] + return discover_functions_from_source(source, file_path, filter_criteria, analyzer) + + +def discover_functions_from_source( + source: str, + file_path: Path, + filter_criteria: FunctionFilterCriteria | None = None, + analyzer: GoAnalyzer | None = None, +) -> list[FunctionToOptimize]: + from codeflash.models.function_types import FunctionParent, FunctionToOptimize + + if analyzer is None: + analyzer = GoAnalyzer() + + results: list[FunctionToOptimize] = [] + + functions = analyzer.find_functions(source) + for func in functions: + if not _should_include_function(func, filter_criteria, file_path): + continue + results.append( + FunctionToOptimize( + function_name=func.name, + file_path=file_path, + parents=[], + starting_line=func.starting_line, + ending_line=func.ending_line, + starting_col=func.starting_col, + ending_col=func.ending_col, + is_async=False, + is_method=False, + language="go", + doc_start_line=func.doc_start_line, + ) + ) + + methods = analyzer.find_methods(source) + for method in methods: + if not _should_include_method(method, filter_criteria, file_path): + continue + results.append( + FunctionToOptimize( + function_name=method.name, + file_path=file_path, + parents=[FunctionParent(name=method.receiver_name, type="StructDef")], + starting_line=method.starting_line, + ending_line=method.ending_line, + starting_col=method.starting_col, + ending_col=method.ending_col, + is_async=False, + is_method=True, + language="go", + doc_start_line=method.doc_start_line, + ) + ) + + return results + + +def _should_include_function(func: GoFunctionNode, criteria: FunctionFilterCriteria | None, file_path: Path) -> bool: + if file_path.name.endswith("_test.go"): + return False + + if func.name in _SKIP_FUNCTION_NAMES: + return False + + if criteria is None: + return True + + if criteria.require_export and not func.is_exported: + return False + + if criteria.require_return and not func.has_return_type: + return False + + if criteria.matches_exclude_patterns(func.name): + return False + + if not criteria.matches_include_patterns(func.name): + return False + + line_count = func.ending_line - func.starting_line + 1 + if criteria.min_lines is not None and line_count < criteria.min_lines: + return False + if criteria.max_lines is not None and line_count > criteria.max_lines: + return False + + return True + + +def _should_include_method(method: GoMethodNode, criteria: FunctionFilterCriteria | None, file_path: Path) -> bool: + if file_path.name.endswith("_test.go"): + return False + + if criteria is None: + return True + + if not criteria.include_methods: + return False + + if criteria.require_export and not method.is_exported: + return False + + if criteria.require_return and not method.has_return_type: + return False + + if criteria.matches_exclude_patterns(method.name): + return False + + if not criteria.matches_include_patterns(method.name): + return False + + line_count = method.ending_line - method.starting_line + 1 + if criteria.min_lines is not None and line_count < criteria.min_lines: + return False + if criteria.max_lines is not None and line_count > criteria.max_lines: + return False + + return True diff --git a/codeflash/languages/golang/formatter.py b/codeflash/languages/golang/formatter.py new file mode 100644 index 000000000..61a0bfe58 --- /dev/null +++ b/codeflash/languages/golang/formatter.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import logging +import shutil +import subprocess +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +def format_go_code(source: str, file_path: Path | None = None) -> str: + goimports = _find_go_tool("goimports") + if goimports is not None: + formatted = _run_formatter(goimports, source) + if formatted is not None: + return formatted + + gofmt = _find_go_tool("gofmt") + if gofmt is not None: + formatted = _run_formatter(gofmt, source) + if formatted is not None: + return formatted + + logger.debug("No Go formatter found (goimports/gofmt), returning source unchanged") + return source + + +def _find_go_tool(name: str) -> str | None: + import os + from pathlib import Path + + found = shutil.which(name) + if found: + return found + gopath = os.environ.get("GOPATH") or str(Path.home() / "go") + for bin_dir in ("bin", str(Path("packages") / "bin")): + candidate = Path(gopath) / bin_dir / name + if candidate.is_file() and os.access(candidate, os.X_OK): + return str(candidate) + return None + + +def _run_formatter(tool: str, source: str) -> str | None: + try: + result = subprocess.run([tool], input=source, capture_output=True, text=True, timeout=15, check=False) + if result.returncode == 0: + return result.stdout + logger.debug("%s failed: %s", tool, result.stderr) + except subprocess.TimeoutExpired: + logger.warning("%s timed out", tool) + except Exception: + logger.debug("%s error", tool, exc_info=True) + return None + + +def normalize_go_code(source: str) -> str: + lines = source.splitlines() + normalized: list[str] = [] + in_block_comment = False + + for line in lines: + if in_block_comment: + if "*/" in line: + in_block_comment = False + line = line[line.index("*/") + 2 :] + else: + continue + + if "//" in line: + comment_pos = _find_line_comment_pos(line) + if comment_pos >= 0: + line = line[:comment_pos] + + if "/*" in line: + start_idx = line.index("/*") + if "*/" in line[start_idx:]: + end_idx = line.index("*/", start_idx) + line = line[:start_idx] + line[end_idx + 2 :] + else: + in_block_comment = True + line = line[:start_idx] + + stripped = line.strip() + if stripped: + normalized.append(stripped) + + return "\n".join(normalized) + + +def _find_line_comment_pos(line: str) -> int: + in_string = False + in_rune = False + escape_next = False + in_raw_string = False + + i = 0 + while i < len(line): + char = line[i] + + if escape_next: + escape_next = False + i += 1 + continue + + if in_raw_string: + if char == "`": + in_raw_string = False + i += 1 + continue + + if char == "`": + in_raw_string = True + i += 1 + continue + + if char == "\\": + escape_next = True + i += 1 + continue + + if char == '"' and not in_rune: + in_string = not in_string + elif char == "'" and not in_string: + in_rune = not in_rune + elif not in_string and not in_rune and i < len(line) - 1 and line[i : i + 2] == "//": + return i + + i += 1 + + return -1 diff --git a/codeflash/languages/golang/function_optimizer.py b/codeflash/languages/golang/function_optimizer.py new file mode 100644 index 000000000..0b679da3b --- /dev/null +++ b/codeflash/languages/golang/function_optimizer.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import hashlib +from collections import defaultdict +from typing import TYPE_CHECKING + +from codeflash.code_utils.code_utils import encoded_tokens_len +from codeflash.code_utils.config_consts import ( + OPTIMIZATION_CONTEXT_TOKEN_LIMIT, + READ_WRITABLE_LIMIT_ERROR, + TESTGEN_CONTEXT_TOKEN_LIMIT, + TESTGEN_LIMIT_ERROR, +) +from codeflash.either import Failure, Success +from codeflash.languages.function_optimizer import FunctionOptimizer +from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource +from codeflash.verification.equivalence import compare_test_results + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.either import Result + from codeflash.languages.base import CodeContext, HelperFunction + from codeflash.models.models import OriginalCodeBaseline, TestDiff, TestResults + + +class GoFunctionOptimizer(FunctionOptimizer): + def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: + from codeflash.languages import get_language_support + from codeflash.languages.language_enum import Language + + language = Language(self.function_to_optimize.language) + lang_support = get_language_support(language) + + try: + code_context = lang_support.extract_code_context( + self.function_to_optimize, self.project_root, self.project_root + ) + return Success( + _build_optimization_context( + code_context, + self.function_to_optimize.file_path, + self.function_to_optimize.language, + self.project_root, + ) + ) + except ValueError as e: + return Failure(str(e)) + + def compare_candidate_results( + self, + baseline_results: OriginalCodeBaseline, + candidate_behavior_results: TestResults, + optimization_candidate_index: int, + ) -> tuple[bool, list[TestDiff]]: + return compare_test_results( + baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True + ) + + def replace_function_and_helpers_with_optimized_code( + self, + code_context: CodeOptimizationContext, + optimized_code: CodeStringsMarkdown, + original_helper_code: dict[Path, str], + ) -> bool: + from codeflash.languages.code_replacer import replace_function_definitions_for_language + from codeflash.languages.golang.formatter import format_go_code + + did_update = False + modified_files: list[Path] = [] + for module_abspath, qualified_names in self.group_functions_by_file(code_context).items(): + updated = replace_function_definitions_for_language( + function_names=list(qualified_names), + optimized_code=optimized_code, + module_abspath=module_abspath, + project_root_path=self.project_root, + lang_support=self.language_support, + function_to_optimize=self.function_to_optimize, + ) + if updated: + modified_files.append(module_abspath) + did_update |= updated + + for file_path in modified_files: + source = file_path.read_text(encoding="utf-8") + formatted = format_go_code(source, file_path) + if formatted != source: + file_path.write_text(formatted, encoding="utf-8") + + return did_update + + +def _extract_package_name(file_path: Path) -> str | None: + from codeflash.languages.golang.parser import GoAnalyzer + + try: + source = file_path.read_text(encoding="utf-8") + except OSError: + return None + return GoAnalyzer().find_package_name(source) + + +def _build_optimization_context( + code_context: CodeContext, + file_path: Path, + language: str, + project_root: Path, + optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, + testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, +) -> CodeOptimizationContext: + package_name = _extract_package_name(file_path) + + if code_context.imports: + inner = "\n".join(f"\t{imp}" for imp in code_context.imports) + imports_code = f"import (\n{inner}\n)" + else: + imports_code = "" + + try: + target_relative_path = file_path.resolve().relative_to(project_root.resolve()) + except ValueError: + target_relative_path = file_path + + helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list) + helper_function_sources = [] + + for helper in code_context.helper_functions: + helpers_by_file[helper.file_path].append(helper) + helper_function_sources.append( + FunctionSource( + file_path=helper.file_path, + qualified_name=helper.qualified_name, + fully_qualified_name=helper.qualified_name, + only_function_name=helper.name, + source_code=helper.source_code, + ) + ) + + target_file_code = code_context.target_code + same_file_helpers = helpers_by_file.get(file_path, []) + if same_file_helpers: + helper_code = "\n\n".join(h.source_code for h in same_file_helpers) + target_file_code = target_file_code + "\n\n" + helper_code + + if imports_code: + target_file_code = imports_code + "\n\n" + target_file_code + + if package_name: + target_file_code = f"package {package_name}\n\n" + target_file_code + + read_writable_code_strings = [CodeString(code=target_file_code, file_path=target_relative_path, language=language)] + + for helper_file_path, file_helpers in helpers_by_file.items(): + if helper_file_path == file_path: + continue + try: + helper_relative_path = helper_file_path.resolve().relative_to(project_root.resolve()) + except ValueError: + helper_relative_path = helper_file_path + combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) + read_writable_code_strings.append( + CodeString(code=combined_helper_code, file_path=helper_relative_path, language=language) + ) + + read_writable_code = CodeStringsMarkdown(code_strings=read_writable_code_strings, language=language) + testgen_context = CodeStringsMarkdown(code_strings=read_writable_code_strings.copy(), language=language) + + read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) + if read_writable_tokens > optim_token_limit: + raise ValueError(READ_WRITABLE_LIMIT_ERROR) + + testgen_tokens = encoded_tokens_len(testgen_context.markdown) + if testgen_tokens > testgen_token_limit: + raise ValueError(TESTGEN_LIMIT_ERROR) + + code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest() + + return CodeOptimizationContext( + testgen_context=testgen_context, + read_writable_code=read_writable_code, + read_only_context_code=code_context.read_only_context, + hashing_code_context=read_writable_code.flat, + hashing_code_context_hash=code_hash, + helper_functions=helper_function_sources, + testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources], + preexisting_objects=set(), + ) diff --git a/codeflash/languages/golang/instrumentation.py b/codeflash/languages/golang/instrumentation.py new file mode 100644 index 000000000..122d9e6c8 --- /dev/null +++ b/codeflash/languages/golang/instrumentation.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import re + +_FUNC_BODY_RE = re.compile(r"^(func\s+)(Test\w+)(\s*\(\s*)(\w+)(\s+\*testing\.T\s*\)\s*\{)", re.MULTILINE) +_PARALLEL_RE = re.compile(r"^\s*\w+\.Parallel\(\)\s*\n?", re.MULTILINE) +_HELPER_RE = re.compile(r"^\s*\w+\.Helper\(\)\s*\n?", re.MULTILINE) + + +def _test_matches_target(test_name: str, target_function_name: str) -> bool: + remainder = test_name[len("Test") :] + segments = remainder.split("_") + return target_function_name in segments + + +def convert_tests_to_benchmarks(test_source: str, target_function_name: str = "") -> str: + if not test_source.strip(): + return test_source + + if not _FUNC_BODY_RE.search(test_source): + return test_source + + result = test_source + + for match in reversed(list(_FUNC_BODY_RE.finditer(result))): + func_prefix = match.group(1) + test_name = match.group(2) + paren_open = match.group(3) + param_name = match.group(4) + + body_start = match.end() + brace_depth = 1 + pos = body_start + while pos < len(result) and brace_depth > 0: + if result[pos] == "{": + brace_depth += 1 + elif result[pos] == "}": + brace_depth -= 1 + pos += 1 + + if target_function_name and not _test_matches_target(test_name, target_function_name): + result = result[: match.start()] + result[pos:] + continue + + body = result[body_start : pos - 1] + bench_name = "Benchmark" + test_name[len("Test") :] + + new_sig = f"{func_prefix}{bench_name}{paren_open}{param_name} *testing.B) {{\n\tfor i := 0; i < {param_name}.N; i++ {{" + new_func = f"{new_sig}{body}\t}}\n}}" + result = result[: match.start()] + new_func + result[pos:] + + result = result.replace("*testing.T", "*testing.B") + result = _PARALLEL_RE.sub("", result) + return _HELPER_RE.sub("", result) diff --git a/codeflash/languages/golang/parse.py b/codeflash/languages/golang/parse.py new file mode 100644 index 000000000..e6c0f09e7 --- /dev/null +++ b/codeflash/languages/golang/parse.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import json +import logging +import re +from typing import TYPE_CHECKING, Any + +from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults + +if TYPE_CHECKING: + import subprocess + from pathlib import Path + + from codeflash.models.models import TestFiles + from codeflash.models.test_type import TestType + from codeflash.verification.verification_utils import TestConfig + +logger = logging.getLogger(__name__) + +BENCHMARK_RE = re.compile( + r"^(Benchmark\w+)(?:-\d+)?\s+" + r"(\d+)\s+" + r"([\d.]+)\s+ns/op" + r"(?:\s+(\d+)\s+B/op)?" + r"(?:\s+(\d+)\s+allocs/op)?" +) + + +def parse_go_test_output( + test_json_path: Path, + test_files: TestFiles, + test_config: TestConfig, + run_result: subprocess.CompletedProcess[str] | None = None, +) -> TestResults: + test_results = TestResults() + + content = _read_json_output(test_json_path, run_result) + if not content: + logger.warning("No Go test output to parse from %s", test_json_path) + return test_results + + events = _parse_json_lines(content) + if not events: + logger.warning("No valid JSON events found in %s", test_json_path) + return test_results + + iterations: list[_TestIteration] = [] + active: dict[str, _TestIteration] = {} + + for event in events: + action = event.get("Action") + test_name = event.get("Test") + package = event.get("Package", "") + + if test_name is None: + if action == "output": + output_text = event.get("Output", "") + bench_match = BENCHMARK_RE.search(output_text) + if bench_match: + bench_name = bench_match.group(1) + it = _TestIteration(test_name=bench_name, package=package) + it.passed = True + it.bench_ns_per_op = float(bench_match.group(3)) + it.bench_iterations = int(bench_match.group(2)) + it.stdout = output_text + iterations.append(it) + continue + + if action == "run": + if test_name in active: + iterations.append(active[test_name]) + active[test_name] = _TestIteration(test_name=test_name, package=package) + continue + + maybe_it = active.get(test_name) + if maybe_it is None: + maybe_it = _TestIteration(test_name=test_name, package=package) + active[test_name] = maybe_it + it = maybe_it + + if action == "output": + output_text = event.get("Output", "") + it.stdout += output_text + bench_match = BENCHMARK_RE.search(output_text) + if bench_match: + it.bench_ns_per_op = float(bench_match.group(3)) + it.bench_iterations = int(bench_match.group(2)) + elif action in ("pass", "fail"): + it.passed = action == "pass" + elapsed = event.get("Elapsed", 0) + it.elapsed_ns = int(elapsed * 1_000_000_000) if elapsed else None + iterations.append(active.pop(test_name)) + + for it in active.values(): + if it.passed is not None: + iterations.append(it) + + loop_counters: dict[str, int] = {} + base_dir = test_config.tests_project_rootdir + + for it in iterations: + if it.passed is None: + continue + + loop_index = loop_counters.get(it.test_name, 0) + 1 + loop_counters[it.test_name] = loop_index + + runtime_ns = it.bench_ns_per_op if it.bench_ns_per_op is not None else it.elapsed_ns + if runtime_ns is not None: + runtime_ns = int(runtime_ns) + + test_file_path = _resolve_test_file(it.test_name, it.package, test_files, base_dir) + test_type = _resolve_test_type(test_file_path, test_files) + + test_results.add( + FunctionTestInvocation( + loop_index=loop_index, + id=InvocationId( + test_module_path=it.package, + test_class_name=None, + test_function_name=it.test_name, + function_getting_tested="", + iteration_id="", + ), + file_name=test_file_path, + runtime=runtime_ns, + test_framework="go-test", + did_pass=it.passed, + test_type=test_type, + return_value=None, + timed_out=False, + stdout=it.stdout, + ) + ) + + if not test_results: + logger.info("No Go test results parsed from %s", test_json_path) + if run_result is not None: + logger.debug("stdout: %s\nstderr: %s", run_result.stdout, run_result.stderr) + + logger.debug("[BENCHMARK-DONE] Got %d benchmark results", len(test_results)) + + return test_results + + +class _TestIteration: + __slots__ = ("bench_iterations", "bench_ns_per_op", "elapsed_ns", "package", "passed", "stdout", "test_name") + + def __init__(self, test_name: str, package: str) -> None: + self.test_name = test_name + self.package = package + self.passed: bool | None = None + self.elapsed_ns: int | None = None + self.bench_ns_per_op: float | None = None + self.bench_iterations: int | None = None + self.stdout: str = "" + + +def _read_json_output(path: Path, run_result: subprocess.CompletedProcess[str] | None) -> str: + try: + content = path.read_text(encoding="utf-8") + if content.strip(): + return content + except Exception: + pass + if run_result is not None: + return run_result.stdout or "" + return "" + + +def _parse_json_lines(content: str) -> list[dict[str, Any]]: + events: list[dict[str, Any]] = [] + for line in content.splitlines(): + line = line.strip() + if not line: + continue + try: + events.append(json.loads(line)) + except json.JSONDecodeError: + continue + return events + + +def _resolve_test_file(test_name: str, package: str, test_files: TestFiles, base_dir: Path) -> Path: + + for tf in test_files.test_files: + behavior_path = tf.instrumented_behavior_file_path + if behavior_path.exists(): + return behavior_path + if tf.original_file_path and tf.original_file_path.exists(): + return tf.original_file_path + + if package: + return base_dir / package.replace("/", "_") + return base_dir / f"{test_name}.go" + + +def _resolve_test_type(test_file_path: Path, test_files: TestFiles) -> TestType: + from codeflash.models.test_type import TestType + + test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path) + if test_type is not None: + return test_type + test_type = test_files.get_test_type_by_original_file_path(test_file_path) + if test_type is not None: + return test_type + if test_files.test_files: + return test_files.test_files[0].test_type + return TestType.GENERATED_REGRESSION diff --git a/codeflash/languages/golang/parser.py b/codeflash/languages/golang/parser.py new file mode 100644 index 000000000..caa3dbb44 --- /dev/null +++ b/codeflash/languages/golang/parser.py @@ -0,0 +1,422 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from tree_sitter import Language, Parser + +if TYPE_CHECKING: + from tree_sitter import Node, Tree + +logger = logging.getLogger(__name__) + +_GO_LANGUAGE: Language | None = None +_GO_PARSER: Parser | None = None + + +def _get_go_language() -> Language: + global _GO_LANGUAGE + if _GO_LANGUAGE is None: + import tree_sitter_go + + _GO_LANGUAGE = Language(tree_sitter_go.language()) + return _GO_LANGUAGE + + +def _get_go_parser() -> Parser: + global _GO_PARSER + if _GO_PARSER is None: + _GO_PARSER = Parser(_get_go_language()) + return _GO_PARSER + + +@dataclass(frozen=True) +class GoFunctionNode: + name: str + starting_line: int + ending_line: int + starting_col: int + ending_col: int + is_exported: bool + has_return_type: bool + doc_start_line: int | None = None + + +@dataclass(frozen=True) +class GoMethodNode: + name: str + receiver_name: str + receiver_is_pointer: bool + starting_line: int + ending_line: int + starting_col: int + ending_col: int + is_exported: bool + has_return_type: bool + doc_start_line: int | None = None + + +@dataclass(frozen=True) +class GoStructNode: + name: str + starting_line: int + ending_line: int + fields: list[str] = field(default_factory=list) + + +@dataclass(frozen=True) +class GoInterfaceNode: + name: str + starting_line: int + ending_line: int + methods: list[str] = field(default_factory=list) + + +@dataclass(frozen=True) +class GoImportInfo: + path: str + alias: str | None + starting_line: int + ending_line: int + + +@dataclass(frozen=True) +class GoGlobalDeclaration: + names: tuple[str, ...] + kind: str + source_code: str + starting_line: int + ending_line: int + + +class GoAnalyzer: + def __init__(self) -> None: + self._parser = _get_go_parser() + self._source_bytes: bytes | None = None + self._tree: Tree | None = None + + @property + def last_tree(self) -> Tree | None: + return self._tree + + def parse(self, source: str) -> Tree: + self._source_bytes = source.encode("utf-8") + self._tree = self._parser.parse(self._source_bytes) + return self._tree + + def get_node_text(self, node: Node) -> str: + if self._source_bytes is None: + return "" + return self._source_bytes[node.start_byte : node.end_byte].decode("utf-8") + + def validate_syntax(self, source: str) -> bool: + tree = self.parse(source) + return not tree.root_node.has_error + + def find_functions(self, source: str) -> list[GoFunctionNode]: + tree = self.parse(source) + results: list[GoFunctionNode] = [] + for node in tree.root_node.children: + if node.type == "function_declaration": + func = self._parse_function_node(node) + if func is not None: + results.append(func) + return results + + def find_methods(self, source: str) -> list[GoMethodNode]: + tree = self.parse(source) + results: list[GoMethodNode] = [] + for node in tree.root_node.children: + if node.type == "method_declaration": + method = self._parse_method_node(node) + if method is not None: + results.append(method) + return results + + def find_structs(self, source: str) -> list[GoStructNode]: + tree = self.parse(source) + results: list[GoStructNode] = [] + for node in tree.root_node.children: + if node.type == "type_declaration": + for spec in _children_of_type(node, "type_spec"): + type_node = spec.child_by_field_name("type") + if type_node is not None and type_node.type == "struct_type": + name_node = spec.child_by_field_name("name") + if name_node is not None: + fields = self._extract_struct_fields(type_node) + results.append( + GoStructNode( + name=self.get_node_text(name_node), + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + fields=fields, + ) + ) + return results + + def find_interfaces(self, source: str) -> list[GoInterfaceNode]: + tree = self.parse(source) + results: list[GoInterfaceNode] = [] + for node in tree.root_node.children: + if node.type == "type_declaration": + for spec in _children_of_type(node, "type_spec"): + type_node = spec.child_by_field_name("type") + if type_node is not None and type_node.type == "interface_type": + name_node = spec.child_by_field_name("name") + if name_node is not None: + methods = self._extract_interface_methods(type_node) + results.append( + GoInterfaceNode( + name=self.get_node_text(name_node), + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + methods=methods, + ) + ) + return results + + def find_imports(self, source: str) -> list[GoImportInfo]: + tree = self.parse(source) + results: list[GoImportInfo] = [] + for node in tree.root_node.children: + if node.type == "import_declaration": + for spec in _iter_import_specs(node): + path_node = spec.child_by_field_name("path") + if path_node is None: + continue + import_path = self.get_node_text(path_node).strip('"') + alias_node = spec.child_by_field_name("name") + alias = self.get_node_text(alias_node) if alias_node is not None else None + results.append( + GoImportInfo( + path=import_path, + alias=alias, + starting_line=spec.start_point.row + 1, + ending_line=spec.end_point.row + 1, + ) + ) + return results + + def find_global_declarations(self, source: str) -> list[GoGlobalDeclaration]: + tree = self.parse(source) + results: list[GoGlobalDeclaration] = [] + for node in tree.root_node.children: + if node.type in ("var_declaration", "const_declaration"): + kind = "var" if node.type == "var_declaration" else "const" + names = _extract_declaration_names(node, self) + if names: + results.append( + GoGlobalDeclaration( + names=tuple(names), + kind=kind, + source_code=self.get_node_text(node), + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + ) + ) + return results + + def collect_body_identifiers(self, source: str, func_name: str, receiver_type: str | None = None) -> set[str]: + tree = self.parse(source) + for node in tree.root_node.children: + if receiver_type is None and node.type == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node is not None and self.get_node_text(name_node) == func_name: + body = node.child_by_field_name("body") + return _collect_identifiers(body) if body else set() + if receiver_type is not None and node.type == "method_declaration": + name_node = node.child_by_field_name("name") + if name_node is None or self.get_node_text(name_node) != func_name: + continue + recv_node = node.child_by_field_name("receiver") + if recv_node is not None: + recv_name, _ = self.parse_receiver(recv_node) + if recv_name == receiver_type: + body = node.child_by_field_name("body") + return _collect_identifiers(body) if body else set() + return set() + + def find_package_name(self, source: str) -> str | None: + tree = self.parse(source) + for node in tree.root_node.children: + if node.type == "package_clause": + for child in node.children: + if child.type == "package_identifier": + return self.get_node_text(child) + return None + + def _parse_function_node(self, node: Node) -> GoFunctionNode | None: + name_node = node.child_by_field_name("name") + if name_node is None: + return None + name = self.get_node_text(name_node) + result_node = node.child_by_field_name("result") + doc_line = _find_preceding_comment_line(node) + return GoFunctionNode( + name=name, + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + starting_col=node.start_point.column, + ending_col=node.end_point.column, + is_exported=name[0].isupper(), + has_return_type=result_node is not None, + doc_start_line=doc_line, + ) + + def _parse_method_node(self, node: Node) -> GoMethodNode | None: + name_node = node.child_by_field_name("name") + if name_node is None: + return None + name = self.get_node_text(name_node) + + receiver_node = node.child_by_field_name("receiver") + if receiver_node is None: + return None + receiver_name, receiver_is_pointer = self.parse_receiver(receiver_node) + if receiver_name is None: + return None + + result_node = node.child_by_field_name("result") + doc_line = _find_preceding_comment_line(node) + return GoMethodNode( + name=name, + receiver_name=receiver_name, + receiver_is_pointer=receiver_is_pointer, + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + starting_col=node.start_point.column, + ending_col=node.end_point.column, + is_exported=name[0].isupper(), + has_return_type=result_node is not None, + doc_start_line=doc_line, + ) + + def parse_receiver(self, receiver_node: Node) -> tuple[str | None, bool]: + for param in _children_of_type(receiver_node, "parameter_declaration"): + type_node = param.child_by_field_name("type") + if type_node is None: + continue + if type_node.type == "pointer_type": + inner = type_node.child(1) + if inner is not None: + return self.get_node_text(inner), True + elif type_node.type == "type_identifier": + return self.get_node_text(type_node), False + return None, False + + def _extract_struct_fields(self, struct_node: Node) -> list[str]: + fields: list[str] = [] + for child in struct_node.children: + if child.type == "field_declaration_list": + for fc in child.children: + if fc.type == "field_declaration": + fields.append(self.get_node_text(fc).strip()) + break + return fields + + def _extract_interface_methods(self, iface_node: Node) -> list[str]: + methods: list[str] = [] + for child in iface_node.children: + if child.type == "method_elem": + methods.append(self.get_node_text(child).strip()) + return methods + + def extract_function_source(self, source: str, func_name: str, receiver_type: str | None = None) -> str | None: + tree = self.parse(source) + for node in tree.root_node.children: + if receiver_type is None and node.type == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node is not None and self.get_node_text(name_node) == func_name: + return self._get_source_with_doc(node) + + if receiver_type is not None and node.type == "method_declaration": + name_node = node.child_by_field_name("name") + if name_node is None or self.get_node_text(name_node) != func_name: + continue + recv_node = node.child_by_field_name("receiver") + if recv_node is not None: + recv_name, _ = self.parse_receiver(recv_node) + if recv_name == receiver_type: + return self._get_source_with_doc(node) + return None + + def _get_source_with_doc(self, node: Node) -> str: + doc_line = _find_preceding_comment_line(node) + if doc_line is not None and self._source_bytes is not None: + lines = self._source_bytes.decode("utf-8").splitlines(keepends=True) + start = doc_line - 1 + end = node.end_point.row + 1 + return "".join(lines[start:end]) + return self.get_node_text(node) + + +def _children_of_type(node: Node, type_name: str) -> list[Node]: + return [child for child in node.children if child.type == type_name] + + +def _iter_import_specs(import_node: Node) -> list[Node]: + results: list[Node] = [] + for child in import_node.children: + if child.type == "import_spec": + results.append(child) + elif child.type == "import_spec_list": + results.extend(c for c in child.children if c.type == "import_spec") + return results + + +def _extract_declaration_names(node: Node, analyzer: GoAnalyzer) -> list[str]: + names: list[str] = [] + for child in node.children: + if child.type in ("var_spec", "const_spec"): + name_node = child.child_by_field_name("name") + if name_node is not None: + names.append(analyzer.get_node_text(name_node)) + elif child.type in ("var_spec_list", "const_spec_list"): + for spec in child.children: + if spec.type in ("var_spec", "const_spec"): + name_node = spec.child_by_field_name("name") + if name_node is not None: + names.append(analyzer.get_node_text(name_node)) + return names + + +def _collect_identifiers(node: Node | None) -> set[str]: + if node is None: + return set() + ids: set[str] = set() + stack = [node] + while stack: + n = stack.pop() + if n.type in ("identifier", "type_identifier", "field_identifier"): + text = n.parent + if text is not None and text.type not in ("parameter_declaration", "short_var_declaration"): + ids.add(n.text.decode("utf-8") if n.text else "") + elif text is not None and text.type == "short_var_declaration": + name_node = text.child_by_field_name("left") + if name_node is not n and (name_node is None or n not in (name_node, *tuple(name_node.children))): + ids.add(n.text.decode("utf-8") if n.text else "") + stack.extend(n.children) + ids.discard("") + return ids + + +def _find_preceding_comment_line(node: Node) -> int | None: + prev = node.prev_named_sibling + if prev is None: + return None + if prev.type != "comment": + return None + if prev.end_point.row + 1 != node.start_point.row: + return None + comment_start = prev.start_point.row + 1 + current = prev + while True: + earlier = current.prev_named_sibling + if earlier is None or earlier.type != "comment": + break + if earlier.end_point.row + 1 != current.start_point.row: + break + comment_start = earlier.start_point.row + 1 + current = earlier + return comment_start diff --git a/codeflash/languages/golang/replacement.py b/codeflash/languages/golang/replacement.py new file mode 100644 index 000000000..c6301312f --- /dev/null +++ b/codeflash/languages/golang/replacement.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from codeflash.languages.golang.parser import GoAnalyzer + +if TYPE_CHECKING: + import tree_sitter + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.golang.parser import GoGlobalDeclaration + +logger = logging.getLogger(__name__) + + +def replace_function( + source: str, function: FunctionToOptimize, new_source: str, analyzer: GoAnalyzer | None = None +) -> str: + analyzer = analyzer or GoAnalyzer() + receiver_type = function.parents[0].name if function.parents else None + + tree = analyzer.parse(source) + target_node = None + + for node in tree.root_node.children: + if receiver_type is None and node.type == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node is not None and analyzer.get_node_text(name_node) == function.function_name: + target_node = node + break + elif receiver_type is not None and node.type == "method_declaration": + name_node = node.child_by_field_name("name") + if name_node is None or analyzer.get_node_text(name_node) != function.function_name: + continue + recv_node = node.child_by_field_name("receiver") + if recv_node is not None: + recv_name, _ = analyzer.parse_receiver(recv_node) + if recv_name == receiver_type: + target_node = node + break + + if target_node is None: + logger.warning("Could not find function %s in source for replacement", function.function_name) + return source + + lines = source.splitlines(keepends=True) + doc_line = _find_doc_comment_start(target_node) + start_line = (doc_line if doc_line is not None else target_node.start_point.row + 1) - 1 + end_line = target_node.end_point.row + 1 + + new_source_stripped = new_source.rstrip("\n") + "\n" + + result_lines = [*lines[:start_line], new_source_stripped, *lines[end_line:]] + return "".join(result_lines) + + +def add_global_declarations(optimized_code: str, original_source: str, analyzer: GoAnalyzer | None = None) -> str: + analyzer = analyzer or GoAnalyzer() + + merged = _merge_imports(optimized_code, original_source, analyzer) + return _merge_global_var_const(optimized_code, merged, analyzer) + + +def _merge_imports(optimized_code: str, original_source: str, analyzer: GoAnalyzer) -> str: + opt_imports = analyzer.find_imports(optimized_code) + orig_imports = analyzer.find_imports(original_source) + orig_paths = {imp.path for imp in orig_imports} + + new_imports = [imp for imp in opt_imports if imp.path not in orig_paths] + if not new_imports: + return original_source + + lines = original_source.splitlines(keepends=True) + + import_block_end = _find_import_block_end(original_source, analyzer) + + new_import_lines = [] + for imp in new_imports: + if imp.alias: + new_import_lines.append(f'\t{imp.alias} "{imp.path}"\n') + else: + new_import_lines.append(f'\t"{imp.path}"\n') + + if orig_imports: + last_import = max(orig_imports, key=lambda i: i.ending_line) + insert_at = last_import.ending_line + for node in analyzer.last_tree.root_node.children if analyzer.last_tree else []: + if node.type == "import_declaration": + for child in node.children: + if child.type == "import_spec_list": + close_paren_line = child.end_point.row + insert_at = close_paren_line + break + return "".join([*lines[:insert_at], *new_import_lines, *lines[insert_at:]]) + + insert_at = import_block_end + import_block = "import (\n" + "".join(new_import_lines) + ")\n\n" + return "".join([*lines[:insert_at], import_block, *lines[insert_at:]]) + + +def _merge_global_var_const(optimized_code: str, original_source: str, analyzer: GoAnalyzer) -> str: + opt_decls = analyzer.find_global_declarations(optimized_code) + if not opt_decls: + return original_source + + orig_decls = analyzer.find_global_declarations(original_source) + orig_names_to_decl: dict[str, GoGlobalDeclaration] = {} + for decl in orig_decls: + for name in decl.names: + orig_names_to_decl[name] = decl + + new_decls: list[str] = [] + replaced_decls: set[int] = set() + + for opt_decl in opt_decls: + overlapping_orig = None + for name in opt_decl.names: + if name in orig_names_to_decl: + overlapping_orig = orig_names_to_decl[name] + break + + if overlapping_orig is None: + new_decls.append(opt_decl.source_code) + elif overlapping_orig.source_code.strip() != opt_decl.source_code.strip(): + orig_id = id(overlapping_orig) + if orig_id not in replaced_decls: + replaced_decls.add(orig_id) + original_source = _replace_declaration_block(original_source, overlapping_orig, opt_decl.source_code) + + if new_decls: + original_source = _insert_new_declarations(original_source, new_decls, analyzer) + + return original_source + + +def _replace_declaration_block(source: str, orig_decl: GoGlobalDeclaration, new_source_code: str) -> str: + lines = source.splitlines(keepends=True) + start = orig_decl.starting_line - 1 + end = orig_decl.ending_line + replacement = new_source_code.rstrip("\n") + "\n" + return "".join([*lines[:start], replacement, *lines[end:]]) + + +def _insert_new_declarations(source: str, new_decls: list[str], analyzer: GoAnalyzer) -> str: + lines = source.splitlines(keepends=True) + + insert_at = _find_declarations_insert_point(source, analyzer) + + block = "\n".join(new_decls) + "\n\n" + return "".join([*lines[:insert_at], block, *lines[insert_at:]]) + + +def _find_declarations_insert_point(source: str, analyzer: GoAnalyzer) -> int: + tree = analyzer.parse(source) + last_line = 0 + for node in tree.root_node.children: + if node.type in ("import_declaration", "var_declaration", "const_declaration"): + candidate = node.end_point.row + 1 + last_line = max(last_line, candidate) + if last_line > 0: + return last_line + + for node in tree.root_node.children: + if node.type == "package_clause": + return node.end_point.row + 1 + return 0 + + +def remove_test_functions(test_source: str, functions_to_remove: list[str], analyzer: GoAnalyzer | None = None) -> str: + analyzer = analyzer or GoAnalyzer() + tree = analyzer.parse(test_source) + lines = test_source.splitlines(keepends=True) + + regions_to_remove: list[tuple[int, int]] = [] + + for node in tree.root_node.children: + if node.type == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node is not None and analyzer.get_node_text(name_node) in functions_to_remove: + doc_start = _find_doc_comment_start(node) + start = (doc_start if doc_start is not None else node.start_point.row + 1) - 1 + end = node.end_point.row + 1 + regions_to_remove.append((start, end)) + + for start, end in reversed(regions_to_remove): + del lines[start:end] + + return "".join(lines) + + +def _find_doc_comment_start(node: tree_sitter.Node) -> int | None: + prev = node.prev_named_sibling + if prev is None: + return None + if prev.type != "comment": + return None + if prev.end_point.row + 1 != node.start_point.row: + return None + comment_start: int = prev.start_point.row + 1 + current = prev + while True: + earlier = current.prev_named_sibling + if earlier is None or earlier.type != "comment": + break + if earlier.end_point.row + 1 != current.start_point.row: + break + comment_start = earlier.start_point.row + 1 + current = earlier + return comment_start + + +def _find_import_block_end(source: str, analyzer: GoAnalyzer) -> int: + tree = analyzer.parse(source) + for node in tree.root_node.children: + if node.type == "package_clause": + return node.end_point.row + 1 + return 0 diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py new file mode 100644 index 000000000..aaa7c8e3f --- /dev/null +++ b/codeflash/languages/golang/support.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from codeflash.languages.base import LanguageSupport +from codeflash.languages.golang.comparator import compare_test_results as _compare_results +from codeflash.languages.golang.config import detect_go_project, detect_go_version +from codeflash.languages.golang.context import extract_code_context as _extract_context +from codeflash.languages.golang.context import find_helper_functions as _find_helpers +from codeflash.languages.golang.discovery import discover_functions_from_source +from codeflash.languages.golang.formatter import format_go_code, normalize_go_code +from codeflash.languages.golang.parser import GoAnalyzer +from codeflash.languages.golang.replacement import add_global_declarations as _add_globals +from codeflash.languages.golang.replacement import remove_test_functions as _remove_tests +from codeflash.languages.golang.replacement import replace_function as _replace_func +from codeflash.languages.golang.test_discovery import discover_tests as _discover_tests +from codeflash.languages.golang.test_runner import parse_test_results as _parse_results +from codeflash.languages.golang.test_runner import run_behavioral_tests as _run_behavioral +from codeflash.languages.golang.test_runner import run_benchmarking_tests as _run_benchmarking +from codeflash.languages.language_enum import Language +from codeflash.languages.registry import register_language + +if TYPE_CHECKING: + from collections.abc import Sequence + + from codeflash.languages.base import ( + CodeContext, + DependencyResolver, + FunctionFilterCriteria, + HelperFunction, + ReferenceInfo, + TestInfo, + ) + from codeflash.models.function_types import FunctionToOptimize + from codeflash.models.models import GeneratedTestsList, InvocationId + +logger = logging.getLogger(__name__) + + +@register_language +class GoSupport(LanguageSupport): + def __init__(self) -> None: + self._analyzer = GoAnalyzer() + self._go_version: str | None = None + self._go_version_detected = False + + @property + def language(self) -> Language: + return Language.GO + + @property + def file_extensions(self) -> tuple[str, ...]: + return (".go",) + + @property + def default_file_extension(self) -> str: + return ".go" + + @property + def test_framework(self) -> str: + return "go-test" + + @property + def comment_prefix(self) -> str: + return "//" + + @property + def dir_excludes(self) -> frozenset[str]: + return frozenset({"vendor", "testdata", ".git", "node_modules"}) + + @property + def language_version(self) -> str | None: + if not self._go_version_detected: + self._go_version = detect_go_version() + self._go_version_detected = True + return self._go_version + + @property + def valid_test_frameworks(self) -> tuple[str, ...]: + return ("go-test",) + + @property + def test_result_serialization_format(self) -> str: + return "json" + + @property + def function_optimizer_class(self) -> type: + from codeflash.languages.golang.function_optimizer import GoFunctionOptimizer + + return GoFunctionOptimizer + + def discover_functions( + self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None + ) -> list[FunctionToOptimize]: + return discover_functions_from_source(source, file_path, filter_criteria, self._analyzer) + + def discover_tests( + self, test_root: Path, source_functions: Sequence[FunctionToOptimize] + ) -> dict[str, list[TestInfo]]: + return _discover_tests(test_root, source_functions) + + def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: + return self._analyzer.validate_syntax(source) + + def parse_test_xml( + self, test_xml_file_path: Path, test_files: Any, test_config: Any, run_result: Any = None + ) -> Any: + from codeflash.languages.golang.parse import parse_go_test_output + + return parse_go_test_output(test_xml_file_path, test_files, test_config, run_result) + + def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: + return _extract_context(function, project_root, module_root, self._analyzer) + + def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: + try: + source = function.file_path.read_text(encoding="utf-8") + except Exception: + return [] + return _find_helpers(source, function, self._analyzer) + + def find_references( + self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 100 + ) -> list[ReferenceInfo]: + return [] + + def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: + return _replace_func(source, function, new_source, self._analyzer) + + def format_code(self, source: str, file_path: Path | None = None) -> str: + return format_go_code(source, file_path) + + def normalize_code(self, source: str) -> str: + return normalize_go_code(source) + + def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str: + return _add_globals(optimized_code, original_source, self._analyzer) + + def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str: + return str(source_file) + + def prepare_module( + self, module_code: str, module_path: Path, project_root: Path + ) -> tuple[dict[Path, Any], None] | None: + from codeflash.models.models import ValidCode + + if not self._analyzer.validate_syntax(module_code): + return None + validated: dict[Path, ValidCode] = { + module_path: ValidCode(source_code=module_code, normalized_code=normalize_go_code(module_code)) + } + return validated, None + + def setup_test_config(self, test_cfg: Any, file_path: Path, current_worktree: Path | None = None) -> bool: + _ = file_path, current_worktree + project_root = getattr(test_cfg, "project_root_path", Path.cwd()) + config = detect_go_project(project_root) + if config is not None and config.go_version: + self._go_version = config.go_version + self._go_version_detected = True + return True + + def detect_module_system(self, project_root: Path, source_file: Path | None = None) -> str | None: + return None + + def run_behavioral_tests( + self, + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + project_root: Path | None = None, + enable_coverage: bool = False, + candidate_index: int = 0, + ) -> tuple[Path, Any, Path | None, Path | None]: + return _run_behavioral(test_paths, test_env, cwd, timeout, project_root, enable_coverage, candidate_index) + + def run_benchmarking_tests( + self, + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + project_root: Path | None = None, + min_loops: int = 5, + max_loops: int = 100_000, + target_duration_seconds: float = 10.0, + inner_iterations: int = 100, + ) -> tuple[Path, Any]: + return _run_benchmarking( + test_paths, + test_env, + cwd, + timeout, + project_root, + min_loops, + max_loops, + target_duration_seconds, + inner_iterations, + ) + + def generate_concolic_tests(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], str]: + return {}, "" + + def run_line_profile_tests(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + def compare_test_results( + self, + original_results_path: Path, + candidate_results_path: Path, + project_root: Path | None = None, + project_classpath: str | None = None, + ) -> tuple[bool, list[Any]]: + return _compare_results(original_results_path, candidate_results_path, project_root, project_classpath) + + def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str: + return source + + def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: + from codeflash.languages.golang.instrumentation import convert_tests_to_benchmarks + + func_name = target_function.function_name if target_function else "" + return convert_tests_to_benchmarks(test_source, func_name) + + def instrument_existing_test( + self, test_path: Path, call_positions: Any, function_to_optimize: Any, tests_project_root: Path, mode: str + ) -> tuple[bool, str | None]: + _ = call_positions, tests_project_root + try: + source = test_path.read_text(encoding="utf-8") + except Exception: + return False, None + if mode == "performance": + from codeflash.languages.golang.instrumentation import convert_tests_to_benchmarks + + func_name = function_to_optimize.function_name if function_to_optimize else "" + source = convert_tests_to_benchmarks(source, func_name) + return True, source + + def postprocess_generated_tests( + self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path + ) -> GeneratedTestsList: + _ = test_framework, project_root, source_file_path + return generated_tests + + def process_generated_test_strings( + self, + generated_test_source: str, + instrumented_behavior_test_source: str, + instrumented_perf_test_source: str, + function_to_optimize: Any, + test_path: Path, + test_cfg: Any, + project_module_system: str | None, + ) -> tuple[str, str, str]: + _ = test_path, test_cfg, project_module_system + from codeflash.languages.golang.instrumentation import convert_tests_to_benchmarks + + func_name = function_to_optimize.function_name if function_to_optimize else "" + instrumented_perf_test_source = convert_tests_to_benchmarks(instrumented_perf_test_source, func_name) + return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source + + def load_coverage(self, *args: Any, **kwargs: Any) -> Any: + return None + + def get_test_file_suffix(self) -> str: + return "_test.go" + + def resolve_test_file_from_class_path(self, test_class_path: str, base_dir: Path) -> Path | None: + return None + + def resolve_test_module_path_for_pr( + self, test_module_path: str, tests_project_rootdir: Path, non_generated_tests: set[Path] + ) -> Path | None: + return None + + def find_test_root(self, project_root: Path) -> Path | None: + return project_root + + def get_runtime_files(self) -> list[Path]: + return [] + + def ensure_runtime_environment(self, project_root: Path) -> bool: + return detect_go_version() is not None + + def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | None: + return None + + def adjust_test_config_for_discovery(self, test_cfg: Any) -> None: + pass + + def add_runtime_comments( + self, test_source: str, original_runtimes: dict[str, Any], optimized_runtimes: dict[str, Any] + ) -> str: + return test_source + + def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str: + return _remove_tests(test_source, functions_to_remove, self._analyzer) + + def add_runtime_comments_to_generated_tests( + self, + generated_tests: GeneratedTestsList, + original_runtimes: dict[InvocationId, list[int]], + optimized_runtimes: dict[InvocationId, list[int]], + tests_project_rootdir: Path | None = None, + ) -> GeneratedTestsList: + _ = original_runtimes, optimized_runtimes, tests_project_rootdir + return generated_tests + + def remove_test_functions_from_generated_tests( + self, generated_tests: GeneratedTestsList, functions_to_remove: list[str] + ) -> GeneratedTestsList: + from codeflash.models.models import GeneratedTests + + updated_tests: list[GeneratedTests] = [] + for test in generated_tests.generated_tests: + updated_tests.append( + GeneratedTests( + generated_original_test_source=self.remove_test_functions( + test.generated_original_test_source, functions_to_remove + ), + instrumented_behavior_test_source=test.instrumented_behavior_test_source, + instrumented_perf_test_source=test.instrumented_perf_test_source, + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + ) + return type(generated_tests)(generated_tests=updated_tests) + + def get_test_dir_for_source(self, test_dir: Path, source_file: Path | None = None) -> Path | None: + if source_file is not None: + return source_file.parent + return test_dir + + def parse_test_results(self, json_output_path: Path, stdout: str) -> Any: + return _parse_results(json_output_path, stdout) diff --git a/codeflash/languages/golang/test_discovery.py b/codeflash/languages/golang/test_discovery.py new file mode 100644 index 000000000..d7ec039fc --- /dev/null +++ b/codeflash/languages/golang/test_discovery.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING + +from codeflash.languages.base import TestInfo + +if TYPE_CHECKING: + from collections.abc import Sequence + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + +logger = logging.getLogger(__name__) + +GO_TEST_FUNC_RE = re.compile(r"^func\s+(Test\w+)\s*\(", re.MULTILINE) + + +def discover_tests(test_root: Path, source_functions: Sequence[FunctionToOptimize]) -> dict[str, list[TestInfo]]: + func_name_to_qn: dict[str, list[str]] = {} + for func in source_functions: + func_name_to_qn.setdefault(func.function_name, []).append(func.qualified_name) + + test_files = list(test_root.rglob("*_test.go")) + result: dict[str, list[TestInfo]] = {} + + for test_file in test_files: + try: + content = test_file.read_text(encoding="utf-8") + except Exception: + logger.debug("Could not read test file %s", test_file) + continue + + test_func_names = GO_TEST_FUNC_RE.findall(content) + for test_func_name in test_func_names: + matched_qns = _match_test_to_functions(test_func_name, content, func_name_to_qn) + for qn in matched_qns: + info = TestInfo(test_name=test_func_name, test_file=test_file) + result.setdefault(qn, []).append(info) + + return result + + +def _match_test_to_functions(test_func_name: str, test_source: str, func_name_to_qn: dict[str, list[str]]) -> list[str]: + matched: list[str] = [] + + target_name = _extract_target_name(test_func_name) + if target_name and target_name in func_name_to_qn: + matched.extend(func_name_to_qn[target_name]) + return matched + + for func_name, qns in func_name_to_qn.items(): + if _test_calls_function(test_source, test_func_name, func_name): + matched.extend(qns) + + return matched + + +def _extract_target_name(test_func_name: str) -> str | None: + if not test_func_name.startswith("Test"): + return None + remainder = test_func_name[4:] + if not remainder: + return None + name = remainder.split("_")[0] + if not name: + return None + return name + + +def _test_calls_function(test_source: str, test_func_name: str, func_name: str) -> bool: + func_body = _extract_test_body(test_source, test_func_name) + if func_body is None: + return False + call_pattern = re.compile(rf"\b{re.escape(func_name)}\s*\(") + return call_pattern.search(func_body) is not None + + +def _extract_test_body(test_source: str, test_func_name: str) -> str | None: + pattern = re.compile(rf"func\s+{re.escape(test_func_name)}\s*\([^)]*\)\s*\{{") + match = pattern.search(test_source) + if match is None: + return None + + start = match.end() + depth = 1 + pos = start + while pos < len(test_source) and depth > 0: + ch = test_source[pos] + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + pos += 1 + + return test_source[start : pos - 1] if depth == 0 else None diff --git a/codeflash/languages/golang/test_runner.py b/codeflash/languages/golang/test_runner.py new file mode 100644 index 000000000..0accd4953 --- /dev/null +++ b/codeflash/languages/golang/test_runner.py @@ -0,0 +1,423 @@ +from __future__ import annotations + +import contextlib +import json +import logging +import os +import re +import signal +import subprocess +import sys +import time +from typing import TYPE_CHECKING, Any + +from codeflash.languages.base import TestResult + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + +logger = logging.getLogger(__name__) + + +def run_behavioral_tests( + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + project_root: Path | None = None, + enable_coverage: bool = False, + candidate_index: int = 0, +) -> tuple[Path, subprocess.CompletedProcess[str], Path | None, Path | None]: + result_dir = cwd / ".codeflash" / "go_test_results" + result_dir.mkdir(parents=True, exist_ok=True) + json_output_file = result_dir / f"behavioral_{candidate_index}.jsonl" + + test_file_paths = _collect_test_file_paths(test_paths) + packages = _test_files_to_packages(test_file_paths, cwd) + if not packages: + packages = ["./..."] + + env = {**os.environ, **test_env} + + others = _collect_other_test_files(test_file_paths) + with _hide_other_test_files(others), _deduplicated_test_files(test_file_paths): + run_regex = _build_run_regex(test_file_paths) + cmd = ["go", "test", "-json", "-v", "-count=1"] + if run_regex: + cmd.extend(["-run", run_regex]) + cmd.extend(packages) + proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) + + json_output_file.write_text(proc_result.stdout or "", encoding="utf-8") + + return json_output_file, proc_result, None, None + + +def run_benchmarking_tests( + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + project_root: Path | None = None, + min_loops: int = 5, + max_loops: int = 100_000, + target_duration_seconds: float = 10.0, + inner_iterations: int = 100, +) -> tuple[Path, subprocess.CompletedProcess[str]]: + result_dir = cwd / ".codeflash" / "go_test_results" + result_dir.mkdir(parents=True, exist_ok=True) + json_output_file = result_dir / "benchmark.jsonl" + + test_file_paths = _collect_test_file_paths(test_paths, use_benchmarking=True) + packages = _test_files_to_packages(test_file_paths, cwd) + if not packages: + packages = ["./..."] + + env = {**os.environ, **test_env} + + others = _collect_other_test_files(test_file_paths) + with _hide_other_test_files(others), _deduplicated_test_files(test_file_paths): + bench_regex = _build_bench_regex(test_file_paths) + if bench_regex: + benchtime_secs = min(target_duration_seconds, 1.0) + num_benchmarks = len(_extract_func_names(test_file_paths, _BENCH_FUNC_RE)) + per_loop_estimate = int(num_benchmarks * benchtime_secs * 2) + 10 + cmd = [ + "go", + "test", + "-json", + "-v", + f"-bench={bench_regex}", + f"-benchtime={benchtime_secs:.0f}s", + # "-benchmem", + "-count=1", # setting count to as we looping manually to track timeout and max_loop + "-run=^$", + f"-timeout={per_loop_estimate}s", + *packages, + ] + # logger.info("Benchmark command: %s", cmd) + all_stdout: list[str] = [] + all_stderr: list[str] = [] + last_returncode = 0 + start_time = time.monotonic() + for loop in range(1, max_loops + 1): + proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=per_loop_estimate) + if proc_result.stdout: + all_stdout.append(proc_result.stdout) + if proc_result.stderr: + all_stderr.append(proc_result.stderr) + last_returncode = proc_result.returncode + if proc_result.returncode != 0: + logger.warning( + "Benchmark loop %d failed (rc=%d):\nstdout:%s\nstderr: %s", + loop, + proc_result.returncode, + proc_result.stdout, + proc_result.stderr, + ) + break + elapsed = time.monotonic() - start_time + if loop >= min_loops and elapsed >= target_duration_seconds: + logger.info( + "Benchmark stopping after %d loops (%.1fs elapsed, target %.1fs)", + loop, + elapsed, + target_duration_seconds, + ) + break + logger.info("Benchmark completed %d loop(s), returncode: %d", loop, last_returncode) + combined_stdout = "".join(all_stdout) + combined_stderr = "".join(all_stderr) + proc_result = subprocess.CompletedProcess( + args=cmd, returncode=last_returncode, stdout=combined_stdout, stderr=combined_stderr + ) + else: + logger.warning("No Benchmark* functions found in perf test files: %s", [str(p) for p in test_file_paths]) + proc_result = subprocess.CompletedProcess(args=[], returncode=0, stdout="", stderr="") + + json_output_file.write_text(proc_result.stdout or "", encoding="utf-8") + + return json_output_file, proc_result + + +def parse_go_test_json(json_output: str) -> list[TestResult]: + results: dict[str, TestResult] = {} + + for line in json_output.splitlines(): + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + except json.JSONDecodeError: + continue + + action = event.get("Action") + test_name = event.get("Test") + if test_name is None: + continue + + package = event.get("Package", "") + + if action == "pass": + elapsed = event.get("Elapsed", 0) + results[test_name] = TestResult( + test_name=test_name, + test_file=_package_to_path(package), + passed=True, + runtime_ns=int(elapsed * 1_000_000_000) if elapsed else None, + ) + elif action == "fail": + elapsed = event.get("Elapsed", 0) + existing = results.get(test_name) + stdout = existing.stdout if existing else "" + results[test_name] = TestResult( + test_name=test_name, + test_file=_package_to_path(package), + passed=False, + runtime_ns=int(elapsed * 1_000_000_000) if elapsed else None, + stdout=stdout, + error_message=f"Test {test_name} failed", + ) + elif action == "output": + output_text = event.get("Output", "") + if test_name in results: + results[test_name] = TestResult( + test_name=results[test_name].test_name, + test_file=results[test_name].test_file, + passed=results[test_name].passed, + runtime_ns=results[test_name].runtime_ns, + stdout=results[test_name].stdout + output_text, + stderr=results[test_name].stderr, + error_message=results[test_name].error_message, + ) + else: + results[test_name] = TestResult( + test_name=test_name, test_file=_package_to_path(package), passed=True, stdout=output_text + ) + + return list(results.values()) + + +def parse_test_results(json_output_path: Path, stdout: str) -> list[TestResult]: + try: + content = json_output_path.read_text(encoding="utf-8") + except Exception: + content = stdout + return parse_go_test_json(content) + + +def _package_to_path(package: str) -> Path: + from pathlib import Path as _Path + + if package: + return _Path(package.replace("/", os.sep)) + return _Path() + + +def _collect_test_file_paths(test_paths: Any, *, use_benchmarking: bool = False) -> list[Path]: + from pathlib import Path as _Path + + if test_paths is None: + return [] + + if hasattr(test_paths, "test_files"): + paths = [] + for tf in test_paths.test_files: + if use_benchmarking: + p = getattr(tf, "benchmarking_file_path", None) or getattr(tf, "perf_file_path", None) + else: + p = getattr(tf, "instrumented_behavior_file_path", None) + if p is None: + p = getattr(tf, "original_file_path", None) + if p is not None: + paths.append(_Path(p)) + return paths + + if isinstance(test_paths, list): + return [_Path(p) for p in test_paths] + + return [] + + +def _collect_other_test_files(test_file_paths: list[Path]) -> list[Path]: + + if not test_file_paths: + return [] + + keep = {f.resolve() for f in test_file_paths} + dirs = {f.resolve().parent for f in test_file_paths} + + others: list[Path] = [] + for d in dirs: + for f in d.glob("*_test.go"): + if f.resolve() not in keep and f.exists(): + others.append(f) + return others + + +@contextlib.contextmanager +def _hide_other_test_files(others: list[Path]) -> Generator[None, None, None]: + """Temporarily rename test files we don't want compiled. + + Go compiles ALL *_test.go files in a package together, so any duplicate + symbols across test files cause build errors. We hide every test file in + the target directories except the ones we intend to run. + """ + renamed: list[tuple[Path, Path]] = [] + for f in others: + hidden = f.with_suffix(".go.codeflash_hidden") + try: + f.rename(hidden) + renamed.append((hidden, f)) + logger.debug("Temporarily hid %s during go test", f) + except OSError: + logger.debug("Could not hide %s, skipping", f) + try: + yield + finally: + for hidden, original in renamed: + try: + hidden.rename(original) + logger.debug("Restored %s", original) + except OSError: + logger.warning("Failed to restore %s from %s", original, hidden) + + +_TEST_FUNC_RE = re.compile(r"^func\s+(Test\w+)\s*\(", re.MULTILINE) +_BENCH_FUNC_RE = re.compile(r"^func\s+(Benchmark\w+)\s*\(", re.MULTILINE) +_FUNC_DECL_RE = re.compile(r"^(func\s+)(Test\w+|Benchmark\w+)(\s*\()", re.MULTILINE) + + +def _extract_func_names(test_files: list[Path], pattern: re.Pattern[str]) -> list[str]: + names: list[str] = [] + for f in test_files: + try: + content = f.read_text(encoding="utf-8") + except OSError: + continue + names.extend(pattern.findall(content)) + return names + + +def _build_run_regex(test_files: list[Path]) -> str | None: + names = _extract_func_names(test_files, _TEST_FUNC_RE) + if not names: + return None + return f"^({'|'.join(re.escape(n) for n in names)})$" + + +def _build_bench_regex(test_files: list[Path]) -> str | None: + names = _extract_func_names(test_files, _BENCH_FUNC_RE) + if not names: + return None + return f"^({'|'.join(re.escape(n) for n in names)})$" + + +def _deduplicate_test_func_names(test_files: list[Path]) -> dict[Path, str]: + seen: dict[str, int] = {} + originals: dict[Path, str] = {} + + for f in test_files: + try: + content = f.read_text(encoding="utf-8") + except OSError: + continue + + names_in_file = [name for _, name, _ in _FUNC_DECL_RE.findall(content)] + if not names_in_file: + continue + + needs_rewrite = any(name in seen for name in names_in_file) + + if not needs_rewrite: + for name in names_in_file: + seen[name] = 1 + continue + + originals[f] = content + + def _renamer(m: re.Match[str]) -> str: + prefix, name, suffix = m.group(1), m.group(2), m.group(3) + if name not in seen: + seen[name] = 1 + return m.group(0) + idx = seen[name] + seen[name] = idx + 1 + return f"{prefix}{name}_{idx}{suffix}" + + new_content = _FUNC_DECL_RE.sub(_renamer, content) + f.write_text(new_content, encoding="utf-8") + logger.debug("Deduplicated test function names in %s", f) + + return originals + + +@contextlib.contextmanager +def _deduplicated_test_files(test_files: list[Path]) -> Generator[None, None, None]: + originals = _deduplicate_test_func_names(test_files) + try: + yield + finally: + for f, content in originals.items(): + try: + f.write_text(content, encoding="utf-8") + except OSError: + logger.warning("Failed to restore original content for %s", f) + + +def _test_files_to_packages(test_files: list[Path], cwd: Path) -> list[str]: + dirs: set[str] = set() + resolved_cwd = cwd.resolve() + for f in test_files: + try: + rel = f.resolve().parent.relative_to(resolved_cwd) + pkg = f"./{rel.as_posix()}" if rel.parts else "." + dirs.add(pkg) + except ValueError: + continue + return sorted(dirs) if dirs else [] + + +def _run_cmd_kill_pg_on_timeout( + cmd: list[str], *, cwd: Path | None = None, env: dict[str, str] | None = None, timeout: int | None = None +) -> subprocess.CompletedProcess[str]: + if sys.platform == "win32": + try: + return subprocess.run(cmd, cwd=cwd, env=env, capture_output=True, text=True, timeout=timeout, check=False) + except subprocess.TimeoutExpired: + return subprocess.CompletedProcess( + args=cmd, returncode=-2, stdout="", stderr=f"Process timed out after {timeout}s" + ) + + proc = subprocess.Popen( + cmd, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, start_new_session=True + ) + try: + stdout, stderr = proc.communicate(timeout=timeout) + return subprocess.CompletedProcess(args=cmd, returncode=proc.returncode, stdout=stdout, stderr=stderr) + except subprocess.TimeoutExpired: + pgid = None + try: + pgid = os.getpgid(proc.pid) + os.killpg(pgid, signal.SIGTERM) + except (ProcessLookupError, OSError): + proc.kill() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + if pgid is not None: + with contextlib.suppress(ProcessLookupError, OSError): + os.killpg(pgid, signal.SIGKILL) + else: + proc.kill() + proc.wait() + try: + stdout_data = proc.stdout.read() if proc.stdout else "" + stderr_data = proc.stderr.read() if proc.stderr else "" + except Exception: + stdout_data, stderr_data = "", "" + return subprocess.CompletedProcess( + args=cmd, returncode=-2, stdout=stdout_data, stderr=stderr_data or f"Process timed out after {timeout}s" + ) diff --git a/codeflash/languages/language_enum.py b/codeflash/languages/language_enum.py index 23187cb30..4b72db62b 100644 --- a/codeflash/languages/language_enum.py +++ b/codeflash/languages/language_enum.py @@ -13,6 +13,7 @@ class Language(str, Enum): JAVASCRIPT = "javascript" TYPESCRIPT = "typescript" JAVA = "java" + GO = "go" def __str__(self) -> str: return self.value diff --git a/codeflash/languages/registry.py b/codeflash/languages/registry.py index 17a528fae..e151a5e5c 100644 --- a/codeflash/languages/registry.py +++ b/codeflash/languages/registry.py @@ -54,6 +54,7 @@ def _ensure_languages_registered() -> None: "codeflash.languages.python.support", "codeflash.languages.javascript.support", "codeflash.languages.java.support", + "codeflash.languages.golang.support", ): with contextlib.suppress(ImportError): importlib.import_module(_lang_module) @@ -227,11 +228,14 @@ def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> py_formatters = ["black", "isort", "ruff", "autopep8", "yapf", "pyfmt"] js_ts_formatters = ["prettier", "eslint", "biome", "rome", "deno", "standard", "tslint"] + go_formatters = ["gofmt", "goimports", "golines"] if any(cmd in py_formatters for cmd in formatter_cmd): ext = ".py" elif any(cmd in js_ts_formatters for cmd in formatter_cmd): ext = ".js" + elif any(cmd in go_formatters for cmd in formatter_cmd): + ext = ".go" if ext is None: # can't determine language diff --git a/codeflash/setup/detector.py b/codeflash/setup/detector.py index 216dd669d..a84f26b72 100644 --- a/codeflash/setup/detector.py +++ b/codeflash/setup/detector.py @@ -172,6 +172,7 @@ def _find_project_root(start_path: Path) -> Path | None: "pom.xml", "build.gradle", "build.gradle.kts", + "go.mod", ] for marker in markers: if (current / marker).exists(): @@ -203,6 +204,11 @@ def _detect_language(project_root: Path) -> tuple[str, float, str]: has_package_json = (project_root / "package.json").exists() has_pom_xml = (project_root / "pom.xml").exists() has_build_gradle = (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists() + has_go_mod = (project_root / "go.mod").exists() + + # Go (go.mod is definitive) + if has_go_mod: + return "go", 1.0, "go.mod found" # Java (pom.xml or build.gradle is definitive) if has_pom_xml: @@ -235,7 +241,10 @@ def _detect_language(project_root: Path) -> tuple[str, float, str]: js_count = len(list(project_root.rglob("*.js"))) ts_count = len(list(project_root.rglob("*.ts"))) java_count = len(list(project_root.rglob("*.java"))) + go_count = len(list(project_root.rglob("*.go"))) + if go_count > 0 and go_count >= max(py_count, js_count, ts_count, java_count): + return "go", 0.5, f"found {go_count} .go files" if java_count > 0 and java_count >= max(py_count, js_count, ts_count): return "java", 0.5, f"found {java_count} .java files" if ts_count > 0: @@ -264,6 +273,8 @@ def _detect_module_root(project_root: Path, language: str) -> tuple[Path, str]: return _detect_js_module_root(project_root) if language == "java": return _detect_java_module_root(project_root) + if language == "go": + return _detect_go_module_root(project_root) return _detect_python_module_root(project_root) @@ -441,6 +452,23 @@ def _detect_java_module_root(project_root: Path) -> tuple[Path, str]: return project_root, "project root" +def _detect_go_module_root(project_root: Path) -> tuple[Path, str]: + """Detect Go module root directory. + + Go projects use go.mod at the module root. The source directory is the + same as the module root (Go packages are directories, not subdirectories). + """ + if (project_root / "go.mod").exists(): + return project_root, "project root (go.mod found)" + + # Check common subdirectories + for subdir in ["cmd", "pkg", "internal"]: + if (project_root / subdir).is_dir(): + return project_root, f"project root ({subdir}/ found)" + + return project_root, "project root" + + def is_build_output_dir(path: Path) -> bool: """Check if a path is within a common build output directory. @@ -474,6 +502,13 @@ def _detect_tests_root(project_root: Path, language: str) -> tuple[Path | None, - spec/ (Ruby/JavaScript) """ + # Go: tests are co-located with source files (*_test.go) + if language == "go": + test_files = list(project_root.rglob("*_test.go")) + if test_files: + return project_root, "project root (Go tests co-located with source)" + return project_root, "project root (Go convention: *_test.go)" + # Java: standard Maven/Gradle test layout if language == "java": import xml.etree.ElementTree as ET @@ -558,6 +593,8 @@ def _detect_test_runner(project_root: Path, language: str) -> tuple[str, str]: return _detect_js_test_runner(project_root) if language == "java": return _detect_java_test_runner(project_root) + if language == "go": + return "go-test", "go test (built-in)" return _detect_python_test_runner(project_root) @@ -686,6 +723,8 @@ def _detect_formatter(project_root: Path, language: str) -> tuple[list[str], str return _detect_js_formatter(project_root) if language == "java": return _detect_java_formatter(project_root) + if language == "go": + return _detect_go_formatter(project_root) return _detect_python_formatter(project_root) @@ -803,6 +842,23 @@ def _detect_js_formatter(project_root: Path) -> tuple[list[str], str]: return [], "none detected" +def _detect_go_formatter(project_root: Path) -> tuple[list[str], str]: + """Detect Go formatter. + + Go has a universal formatter (gofmt). goimports is preferred if available + because it also manages imports. + """ + from codeflash.languages.golang.formatter import _find_go_tool + + goimports = _find_go_tool("goimports") + if goimports: + return [f"{goimports} -w $file"], "goimports (auto-detected)" + gofmt = _find_go_tool("gofmt") + if gofmt: + return [f"{gofmt} -w $file"], "gofmt (auto-detected)" + return ["gofmt -w $file"], "gofmt (default)" + + def _detect_ignore_paths(project_root: Path, language: str) -> tuple[list[Path], str]: """Detect paths to ignore during optimization. @@ -836,6 +892,7 @@ def _detect_ignore_paths(project_root: Path, language: str) -> tuple[list[Path], "javascript": ["node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache"], "typescript": ["node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache"], "java": ["target", "build", ".gradle", ".idea", "out"], + "go": ["vendor", "testdata"], } # Add default ignores @@ -900,6 +957,10 @@ def has_existing_config(project_root: Path) -> tuple[bool, str | None]: except Exception: pass + # Check Go projects — go.mod presence means "configured" + if (project_root / "go.mod").exists(): + return True, "go.mod" + # Check Java build files — zero-config: build file presence means "configured" for build_file in ("pom.xml", "build.gradle", "build.gradle.kts"): if (project_root / build_file).exists(): diff --git a/codeflash/version.py b/codeflash/version.py index 0f1baf8bc..226fdf7ad 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.5.post151.dev0+95b62113" +__version__ = "0.20.5" diff --git a/pyproject.toml b/pyproject.toml index 0a14b35e5..dc73f9917 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "tree-sitter-typescript>=0.23.2", "tree-sitter-java>=0.23.5", "tree-sitter-groovy>=0.1.2", + "tree-sitter-go>=0.23.0", "tree-sitter-kotlin>=1.1.0", "pytest-timeout>=2.4.0", "tomlkit>=0.14.0", diff --git a/tests/test_languages/fixtures/go_project/calculator.go b/tests/test_languages/fixtures/go_project/calculator.go new file mode 100644 index 000000000..787a45f73 --- /dev/null +++ b/tests/test_languages/fixtures/go_project/calculator.go @@ -0,0 +1,53 @@ +package calculator + +import "math" + +// Add returns the sum of two integers. +func Add(a, b int) int { + return a + b +} + +func Subtract(a, b int) int { + return a - b +} + +// unexported function +func multiply(a, b int) int { + return a * b +} + +// no return type +func init() { + // package initialization +} + +func Fibonacci(n int) int { + if n <= 1 { + return n + } + return Fibonacci(n-1) + Fibonacci(n-2) +} + +// Hypotenuse calculates the hypotenuse of a right triangle. +func Hypotenuse(a, b float64) float64 { + return math.Sqrt(a*a + b*b) +} + +type Calculator struct { + Result float64 +} + +// AddFloat adds a value to the calculator result. +func (c *Calculator) AddFloat(val float64) float64 { + c.Result += val + return c.Result +} + +func (c Calculator) GetResult() float64 { + return c.Result +} + +// Reset zeroes the calculator. +func (c *Calculator) Reset() { + c.Result = 0 +} diff --git a/tests/test_languages/fixtures/go_project/calculator_test.go b/tests/test_languages/fixtures/go_project/calculator_test.go new file mode 100644 index 000000000..c8e6e4d66 --- /dev/null +++ b/tests/test_languages/fixtures/go_project/calculator_test.go @@ -0,0 +1,34 @@ +package calculator + +import "testing" + +func TestAdd(t *testing.T) { + result := Add(2, 3) + if result != 5 { + t.Errorf("Add(2, 3) = %d; want 5", result) + } +} + +func TestSubtract(t *testing.T) { + result := Subtract(5, 3) + if result != 2 { + t.Errorf("Subtract(5, 3) = %d; want 2", result) + } +} + +func TestFibonacci(t *testing.T) { + tests := []struct { + input int + expected int + }{ + {0, 0}, + {1, 1}, + {10, 55}, + } + for _, tt := range tests { + result := Fibonacci(tt.input) + if result != tt.expected { + t.Errorf("Fibonacci(%d) = %d; want %d", tt.input, result, tt.expected) + } + } +} diff --git a/tests/test_languages/fixtures/go_project/go.mod b/tests/test_languages/fixtures/go_project/go.mod new file mode 100644 index 000000000..910687fdd --- /dev/null +++ b/tests/test_languages/fixtures/go_project/go.mod @@ -0,0 +1,7 @@ +module github.com/example/myproject + +go 1.22.0 + +require ( + github.com/stretchr/testify v1.9.0 +) diff --git a/tests/test_languages/test_golang/__init__.py b/tests/test_languages/test_golang/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_languages/test_golang/test_comparator.py b/tests/test_languages/test_golang/test_comparator.py new file mode 100644 index 000000000..43bb4631f --- /dev/null +++ b/tests/test_languages/test_golang/test_comparator.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.comparator import compare_test_results + + +class TestCompareTestResults: + def test_equivalent_results(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is True + assert diffs == [] + + def test_candidate_fails_one(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"fail","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is False + assert len(diffs) == 1 + assert diffs[0].test_name == "TestSub" + assert diffs[0].original_passed is True + assert diffs[0].candidate_passed is False + + def test_missing_test_in_candidate(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is False + assert len(diffs) == 1 + assert diffs[0].test_name == "TestSub" + + def test_extra_test_in_candidate(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestNew","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is False + assert len(diffs) == 1 + assert diffs[0].test_name == "TestNew" + + def test_both_empty(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text("", encoding="utf-8") + cand.write_text("", encoding="utf-8") + eq, diffs = compare_test_results(orig, cand) + assert eq is True + assert diffs == [] + + def test_missing_files(self, tmp_path: Path) -> None: + orig = (tmp_path / "missing1.jsonl").resolve() + cand = (tmp_path / "missing2.jsonl").resolve() + eq, diffs = compare_test_results(orig, cand) + assert eq is True + assert diffs == [] + + def test_both_fail_same_test(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"fail","Test":"TestBroken","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"fail","Test":"TestBroken","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is True + assert diffs == [] diff --git a/tests/test_languages/test_golang/test_config.py b/tests/test_languages/test_golang/test_config.py new file mode 100644 index 000000000..c42e3cada --- /dev/null +++ b/tests/test_languages/test_golang/test_config.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.config import detect_go_project, is_go_project + +FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" / "go_project" + + +class TestDetectGoProject: + def test_detects_project(self) -> None: + config = detect_go_project(FIXTURES_DIR) + assert config is not None + assert config.module_path == "github.com/example/myproject" + assert config.go_version == "1.22.0" + + def test_no_go_mod(self, tmp_path: Path) -> None: + config = detect_go_project(tmp_path) + assert config is None + + def test_minimal_go_mod(self, tmp_path: Path) -> None: + go_mod = tmp_path / "go.mod" + go_mod.write_text("module example.com/minimal\n\ngo 1.21\n", encoding="utf-8") + config = detect_go_project(tmp_path) + assert config is not None + assert config.module_path == "example.com/minimal" + assert config.go_version == "1.21" + + def test_vendor_detection(self, tmp_path: Path) -> None: + go_mod = tmp_path / "go.mod" + go_mod.write_text("module example.com/vendored\n\ngo 1.22\n", encoding="utf-8") + (tmp_path / "vendor").mkdir() + config = detect_go_project(tmp_path) + assert config is not None + assert config.has_vendor is True + + +class TestIsGoProject: + def test_with_go_mod(self) -> None: + assert is_go_project(FIXTURES_DIR) is True + + def test_without_go_files(self, tmp_path: Path) -> None: + assert is_go_project(tmp_path) is False + + def test_with_go_files_no_mod(self, tmp_path: Path) -> None: + (tmp_path / "main.go").write_text("package main\n", encoding="utf-8") + assert is_go_project(tmp_path) is True diff --git a/tests/test_languages/test_golang/test_context.py b/tests/test_languages/test_golang/test_context.py new file mode 100644 index 000000000..705ea2d4b --- /dev/null +++ b/tests/test_languages/test_golang/test_context.py @@ -0,0 +1,374 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.base import Language +from codeflash.languages.golang.context import extract_code_context, find_helper_functions +from codeflash.models.function_types import FunctionParent, FunctionToOptimize + +GO_SOURCE_WITH_METHOD = """\ +package calc + +import "math" + +type Calculator struct { +\tResult float64 +} + +// Add returns the sum. +func Add(a, b int) int { +\treturn a + b +} + +func subtract(a, b int) int { +\treturn a - b +} + +func (c *Calculator) AddFloat(val float64) float64 { +\tc.Result += val +\treturn c.Result +} +""" + + +class TestExtractCodeContextFunction: + def test_target_code_with_doc_comment(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == "// Add returns the sum.\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + + def test_target_code_no_doc(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="subtract", file_path=source_file, language="go", starting_line=14, ending_line=16 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == "func subtract(a, b int) int {\n\treturn a - b\n}" + + def test_imports_extracted(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.imports == ['"math"'] + + def test_no_read_only_context_for_function(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.read_only_context == "" + + def test_helpers_only_includes_called_functions(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.helper_functions == [] + + def test_helpers_includes_called_function(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "func helper(x int) int { return x * 2 }\n\n" + "func Target(a int) int { return helper(a) }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(source, encoding="utf-8") + func = FunctionToOptimize(function_name="Target", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + helper_names = [h.name for h in ctx.helper_functions] + assert helper_names == ["helper"] + + def test_language_is_go(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.language == Language.GO + + def test_target_file_path(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_file == source_file + + +class TestExtractCodeContextMethod: + def test_method_target_code(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="AddFloat", + file_path=source_file, + parents=[FunctionParent(name="Calculator", type="StructDef")], + language="go", + is_method=True, + starting_line=18, + ending_line=21, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == ( + "func (c *Calculator) AddFloat(val float64) float64 {\n" + "\tc.Result += val\n" + "\treturn c.Result\n" + "}" + ) + + def test_method_read_only_context_is_struct(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="AddFloat", + file_path=source_file, + parents=[FunctionParent(name="Calculator", type="StructDef")], + language="go", + is_method=True, + starting_line=18, + ending_line=21, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.read_only_context == "type Calculator struct {\n\tResult float64\n}" + + def test_method_helpers_exclude_self(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="AddFloat", + file_path=source_file, + parents=[FunctionParent(name="Calculator", type="StructDef")], + language="go", + is_method=True, + starting_line=18, + ending_line=21, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.helper_functions == [] + + def test_method_helpers_with_calls(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "type Calc struct{ Val int }\n\n" + "func double(x int) int { return x * 2 }\n\n" + "func (c *Calc) Compute() int { return double(c.Val) }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(source, encoding="utf-8") + func = FunctionToOptimize( + function_name="Compute", + file_path=source_file, + parents=[FunctionParent(name="Calc", type="StructDef")], + language="go", + is_method=True, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + helper_names = [h.name for h in ctx.helper_functions] + assert helper_names == ["double"] + assert "Compute" not in helper_names + + +class TestExtractCodeContextEdgeCases: + def test_missing_file(self, tmp_path: Path) -> None: + missing = (tmp_path / "missing.go").resolve() + func = FunctionToOptimize(function_name="Foo", file_path=missing, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == "" + assert ctx.language == Language.GO + + def test_function_not_in_source(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text("package calc\n\nfunc Other() {}\n", encoding="utf-8") + func = FunctionToOptimize(function_name="Missing", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == "" + + def test_multi_import(self, tmp_path: Path) -> None: + source = 'package calc\n\nimport (\n\t"fmt"\n\t"os"\n\tstr "strings"\n)\n\nfunc Hello() string {\n\treturn "hi"\n}\n' + source_file = (tmp_path / "hello.go").resolve() + source_file.write_text(source, encoding="utf-8") + func = FunctionToOptimize(function_name="Hello", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.imports == ['"fmt"', '"os"', 'str "strings"'] + + +GO_SOURCE_WITH_INIT = """\ +package server + +import "sync" + +var ( +\tglobalCache map[string]int +\tmu sync.Mutex +) + +const MaxRetries = 5 + +type Config struct { +\tName string +\tMax int +} + +func init() { +\tglobalCache = make(map[string]int) +\tglobalCache["default"] = 0 +\tmu.Lock() +\tmu.Unlock() +} + +func Process() int { +\treturn MaxRetries +} +""" + + +class TestExtractCodeContextWithInit: + def test_init_in_read_only_context(self, tmp_path: Path) -> None: + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8") + func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert "func init()" in ctx.read_only_context + + def test_init_referenced_globals_in_read_only_context(self, tmp_path: Path) -> None: + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8") + func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert "globalCache" in ctx.read_only_context + assert "mu" in ctx.read_only_context + + def test_init_not_in_helpers(self, tmp_path: Path) -> None: + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8") + func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + helper_names = [h.name for h in ctx.helper_functions] + assert "init" not in helper_names + + def test_no_init_no_extra_context(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize(function_name="Add", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert "func init()" not in ctx.read_only_context + + def test_full_init_read_only_context(self, tmp_path: Path) -> None: + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8") + func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + expected = ( + "var (\n" + "\tglobalCache map[string]int\n" + "\tmu sync.Mutex\n" + ")\n" + "\n" + "func init() {\n" + "\tglobalCache = make(map[string]int)\n" + "\tglobalCache[\"default\"] = 0\n" + "\tmu.Lock()\n" + "\tmu.Unlock()\n" + "}" + ) + assert ctx.read_only_context == expected + + def test_method_with_init_combines_struct_and_init_context(self, tmp_path: Path) -> None: + source = """\ +package server + +var globalOffset = 10 + +type Calc struct { +\tVal int +} + +func init() { +\tglobalOffset = 42 +} + +func (c *Calc) Compute() int { +\treturn c.Val + globalOffset +} +""" + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(source, encoding="utf-8") + func = FunctionToOptimize( + function_name="Compute", + file_path=source_file, + parents=[FunctionParent(name="Calc", type="StructDef")], + language="go", + is_method=True, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert "type Calc struct" in ctx.read_only_context + assert "func init()" in ctx.read_only_context + assert "var globalOffset = 10" in ctx.read_only_context + + +class TestFindHelperFunctions: + def test_skips_init_and_main(self, tmp_path: Path) -> None: + source = "package main\n\nfunc init() { println() }\n\nfunc main() { println() }\n\nfunc Target() int { return 1 }\n" + source_file = (tmp_path / "main.go").resolve() + func = FunctionToOptimize(function_name="Target", file_path=source_file, language="go") + helpers = find_helper_functions(source, func) + helper_names = [h.name for h in helpers] + assert "init" not in helper_names + assert "main" not in helper_names + + def test_method_helpers_have_qualified_names(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "type Calc struct{}\n\n" + "func (c Calc) Target() int { return c.Helper() }\n\n" + "func (c Calc) Helper() int { return 2 }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + func = FunctionToOptimize( + function_name="Target", + file_path=source_file, + parents=[FunctionParent(name="Calc", type="StructDef")], + language="go", + is_method=True, + ) + helpers = find_helper_functions(source, func) + assert len(helpers) == 1 + assert helpers[0].qualified_name == "Calc.Helper" + + def test_transitive_helpers(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "func innerHelper(x int) int { return x }\n\n" + "func outerHelper(x int) int { return innerHelper(x) }\n\n" + "func Target(a int) int { return outerHelper(a) }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + func = FunctionToOptimize(function_name="Target", file_path=source_file, language="go") + helpers = find_helper_functions(source, func) + helper_names = sorted(h.name for h in helpers) + assert helper_names == ["innerHelper", "outerHelper"] + + def test_uncalled_functions_excluded(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "func unrelated() int { return 99 }\n\n" + "func Target(a int) int { return a + 1 }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + func = FunctionToOptimize(function_name="Target", file_path=source_file, language="go") + helpers = find_helper_functions(source, func) + assert helpers == [] diff --git a/tests/test_languages/test_golang/test_discovery.py b/tests/test_languages/test_golang/test_discovery.py new file mode 100644 index 000000000..19d05e6c0 --- /dev/null +++ b/tests/test_languages/test_golang/test_discovery.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.base import FunctionFilterCriteria +from codeflash.languages.golang.discovery import discover_functions_from_source + +GO_SOURCE = """\ +package calculator + +import "math" + +// Add returns the sum of two integers. +func Add(a, b int) int { + return a + b +} + +func subtract(a, b int) int { + return a - b +} + +func init() { + println("setup") +} + +func main() { + println("hello") +} + +func noReturn() { + println("hello") +} + +type Calculator struct { + Result float64 +} + +func (c *Calculator) AddFloat(val float64) float64 { + c.Result += val + return c.Result +} + +func (c Calculator) GetResult() float64 { + return c.Result +} + +func Hypotenuse(a, b float64) float64 { + return math.Sqrt(a*a + b*b) +} +""" + +GO_TEST_SOURCE = """\ +package calculator + +import "testing" + +func TestAdd(t *testing.T) { + result := Add(2, 3) + if result != 5 { + t.Errorf("want 5, got %d", result) + } +} +""" + + +class TestDiscoverFunctions: + def test_discovers_exported_functions(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + names = [f.function_name for f in results] + assert "Add" in names + assert "Hypotenuse" in names + + def test_discovers_unexported_functions(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + names = [f.function_name for f in results] + assert "subtract" in names + assert "noReturn" in names + + def test_skips_init_and_main(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + names = [f.function_name for f in results] + assert "init" not in names + assert "main" not in names + + def test_skips_test_files(self) -> None: + results = discover_functions_from_source(GO_TEST_SOURCE, Path("/project/calc_test.go")) + assert len(results) == 0 + + def test_discovers_methods(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + methods = [f for f in results if f.is_method] + assert len(methods) == 2 + names = [m.function_name for m in methods] + assert "AddFloat" in names + assert "GetResult" in names + + def test_method_parents(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + method = next(f for f in results if f.function_name == "AddFloat") + assert len(method.parents) == 1 + assert method.parents[0].name == "Calculator" + assert method.parents[0].type == "StructDef" + + def test_language_is_go(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + for func in results: + assert func.language == "go" + + def test_is_async_false(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + for func in results: + assert func.is_async is False + + +class TestDiscoverWithFilters: + def test_filter_export_only(self) -> None: + criteria = FunctionFilterCriteria(require_export=True, require_return=False) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + names = [f.function_name for f in results] + assert "Add" in names + assert "Hypotenuse" in names + assert "subtract" not in names + assert "noReturn" not in names + + def test_filter_require_return(self) -> None: + criteria = FunctionFilterCriteria(require_export=False, require_return=True) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + names = [f.function_name for f in results] + assert "Add" in names + assert "noReturn" not in names + + def test_filter_exclude_methods(self) -> None: + criteria = FunctionFilterCriteria(require_export=False, require_return=False, include_methods=False) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + methods = [f for f in results if f.is_method] + assert len(methods) == 0 + + def test_filter_exclude_pattern(self) -> None: + criteria = FunctionFilterCriteria( + require_export=False, require_return=False, exclude_patterns=["subtract"] + ) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + names = [f.function_name for f in results] + assert "subtract" not in names + assert "Add" in names + + def test_filter_include_pattern(self) -> None: + criteria = FunctionFilterCriteria( + require_export=False, require_return=False, include_patterns=["Add*"] + ) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + names = [f.function_name for f in results] + assert "Add" in names + assert "AddFloat" in names + assert "subtract" not in names + assert "Hypotenuse" not in names + + def test_filter_min_lines(self) -> None: + criteria = FunctionFilterCriteria(require_export=False, require_return=False, min_lines=4) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + for func in results: + line_count = func.ending_line - func.starting_line + 1 + assert line_count >= 4 diff --git a/tests/test_languages/test_golang/test_formatter.py b/tests/test_languages/test_golang/test_formatter.py new file mode 100644 index 000000000..4665aaa6d --- /dev/null +++ b/tests/test_languages/test_golang/test_formatter.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from unittest.mock import patch + +from codeflash.languages.golang.formatter import format_go_code, normalize_go_code + + +class TestNormalizeGoCode: + def test_strips_line_comments(self) -> None: + source = "package calc\n\n// Add returns the sum.\nfunc Add(a, b int) int {\n\treturn a + b // fast path\n}\n" + result = normalize_go_code(source) + expected = "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}" + assert result == expected + + def test_strips_single_line_block_comment(self) -> None: + source = "package calc\n\n/* block comment */\nfunc Subtract(a, b int) int {\n\treturn a - b\n}\n" + result = normalize_go_code(source) + expected = "package calc\nfunc Subtract(a, b int) int {\nreturn a - b\n}" + assert result == expected + + def test_strips_multi_line_block_comment(self) -> None: + source = "package calc\n\n/*\nThis is a\nmulti-line comment.\n*/\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + result = normalize_go_code(source) + expected = "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}" + assert result == expected + + def test_preserves_comment_in_string(self) -> None: + source = 'func Greet() string {\n\treturn "hello // world"\n}\n' + result = normalize_go_code(source) + expected = 'func Greet() string {\nreturn "hello // world"\n}' + assert result == expected + + def test_preserves_comment_in_raw_string(self) -> None: + source = "func Greet() string {\n\treturn `hello // world`\n}\n" + result = normalize_go_code(source) + expected = "func Greet() string {\nreturn `hello // world`\n}" + assert result == expected + + def test_strips_whitespace_and_empty_lines(self) -> None: + source = "package calc\n\n\n\nfunc Add(a, b int) int {\n\t\treturn a + b\n\t}\n" + result = normalize_go_code(source) + expected = "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}" + assert result == expected + + def test_mixed_comments(self) -> None: + source = ( + "package calc\n\n" + "// Add returns the sum.\n" + "func Add(a, b int) int {\n" + "\treturn a + b // fast path\n" + "}\n\n" + "/* block comment */\n" + "func Subtract(a, b int) int {\n" + "\treturn a - b\n" + "}\n" + ) + result = normalize_go_code(source) + expected = ( + "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}\nfunc Subtract(a, b int) int {\nreturn a - b\n}" + ) + assert result == expected + + def test_inline_block_comment(self) -> None: + source = "func Add(a /* first */, b int) int {\n\treturn a + b\n}\n" + result = normalize_go_code(source) + expected = "func Add(a , b int) int {\nreturn a + b\n}" + assert result == expected + + def test_empty_input(self) -> None: + assert normalize_go_code("") == "" + + def test_only_comments(self) -> None: + source = "// just a comment\n// another line\n" + result = normalize_go_code(source) + assert result == "" + + +class TestFormatGoCode: + def test_no_formatter_returns_source(self) -> None: + source = "package calc\n\nfunc Add(a, b int) int {\nreturn a+b\n}\n" + with patch("codeflash.languages.golang.formatter._find_go_tool", return_value=None): + result = format_go_code(source) + assert result == source + + def test_format_with_gofmt(self) -> None: + import shutil + + if shutil.which("gofmt") is None: + return + source = "package calc\n\nfunc Add(a,b int)int{\nreturn a+b\n}\n" + result = format_go_code(source) + assert result != source + assert "func Add" in result + + def test_format_failure_returns_source(self) -> None: + source = "this is not valid go" + with patch("codeflash.languages.golang.formatter.shutil.which", return_value="/usr/bin/gofmt"): + with patch("codeflash.languages.golang.formatter.subprocess.run") as mock_run: + mock_run.return_value.returncode = 2 + mock_run.return_value.stderr = "syntax error" + result = format_go_code(source) + assert result == source diff --git a/tests/test_languages/test_golang/test_function_optimizer.py b/tests/test_languages/test_golang/test_function_optimizer.py new file mode 100644 index 000000000..b76c82ba6 --- /dev/null +++ b/tests/test_languages/test_golang/test_function_optimizer.py @@ -0,0 +1,541 @@ +from __future__ import annotations + +import hashlib +from pathlib import Path +from textwrap import dedent +from typing import TYPE_CHECKING + +import pytest + +from codeflash.languages.golang.context import extract_code_context +from codeflash.languages.golang.function_optimizer import _build_optimization_context +from codeflash.models.function_types import FunctionParent, FunctionToOptimize + +if TYPE_CHECKING: + from codeflash.models.models import CodeOptimizationContext + +# --------------------------------------------------------------------------- +# Realistic Go sources used across test classes +# --------------------------------------------------------------------------- + +CALCULATOR_SOURCE = dedent("""\ + package calc + + import ( + \t"fmt" + \t"math" + \tstr "strings" + ) + + // Calculator holds running computation state. + type Calculator struct { + \tResult float64 + \tHistory []float64 + } + + // Formatter controls output rendering. + type Formatter interface { + \tFormat(val float64) string + } + + // Add returns the sum of two integers. + func Add(a, b int) int { + \treturn a + b + } + + func subtract(a, b int) int { + \treturn a - b + } + + func multiply(a, b int) int { + \treturn a * b + } + + // Greet builds a greeting message. + func Greet(name string) string { + \treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name)) + } + + // AddFloat adds a float value and records history. + func (c *Calculator) AddFloat(val float64) float64 { + \tc.Result += val + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + + // Sqrt computes the square root of the current result. + func (c *Calculator) Sqrt() float64 { + \tc.Result = math.Sqrt(c.Result) + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + + // Reset zeroes out the calculator. + func (c Calculator) Reset() Calculator { + \tc.Result = 0 + \tc.History = nil + \treturn c + } +""") + +SIMPLE_SOURCE = dedent("""\ + package simple + + func Double(x int) int { + \treturn x * 2 + } +""") + +INIT_SOURCE = dedent("""\ + package server + + import ( + \t"fmt" + \t"sync" + ) + + var ( + \tglobalCache map[string]int + \tmu sync.Mutex + ) + + var singleVar = 42 + + const MaxRetries = 5 + + type Config struct { + \tName string + \tMax int + } + + func init() { + \tglobalCache = make(map[string]int) + \tglobalCache["default"] = 0 + \tdefaultCfg := Config{Name: "prod", Max: MaxRetries} + \t_ = defaultCfg + \tmu.Lock() + \tmu.Unlock() + } + + func Process() int { + \tfmt.Println("processing") + \treturn singleVar + MaxRetries + } +""") + + +# --------------------------------------------------------------------------- +# Helpers to drive the full extract → build pipeline +# --------------------------------------------------------------------------- + + +def _build_context_for_function( + source: str, + filename: str, + function_name: str, + tmp_path: Path, + parents: list[FunctionParent] | None = None, + is_method: bool = False, +) -> CodeOptimizationContext: + root = tmp_path.resolve() + source_file = (root / filename).resolve() + source_file.write_text(source, encoding="utf-8") + + func = FunctionToOptimize( + function_name=function_name, file_path=source_file, parents=parents or [], language="go", is_method=is_method + ) + code_context = extract_code_context(func, root) + return _build_optimization_context(code_context, source_file, "go", root) + + +# --------------------------------------------------------------------------- +# Tests: targeting a plain exported function +# --------------------------------------------------------------------------- + + +class TestBuildContextExportedFunction: + """Target: Add(a, b int) int — a plain exported function with a doc comment.""" + + def test_full_assembled_code_string(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + code = result.read_writable_code.code_strings[0].code + + expected = dedent("""\ + package calc + + import ( + \t"fmt" + \t"math" + \tstr "strings" + ) + + // Add returns the sum of two integers. + func Add(a, b int) int { + \treturn a + b + } + """) + assert code == expected + + def test_code_includes_package_clause(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert code.startswith("package calc\n") + + def test_code_excludes_struct_definition(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "type Calculator struct" not in code + + def test_code_excludes_interface_definition(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "type Formatter interface" not in code + + def test_no_helpers_when_no_calls(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.helper_functions == [] + + def test_no_read_only_context_for_plain_function(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.read_only_context_code == "" + + def test_relative_path(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.read_writable_code.code_strings[0].file_path == Path("calc.go") + + def test_language_tag(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.read_writable_code.code_strings[0].language == "go" + + def test_testgen_fqns_match_helpers(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + fqns = set(result.testgen_helper_fqns) + helper_fqns = {h.fully_qualified_name for h in result.helper_functions} + assert fqns == helper_fqns + + +# --------------------------------------------------------------------------- +# Tests: targeting a method with a pointer receiver +# --------------------------------------------------------------------------- + + +class TestBuildContextPointerReceiverMethod: + """Target: (c *Calculator) AddFloat(val float64) — pointer receiver method.""" + + def _build(self, tmp_path: Path) -> CodeOptimizationContext: + return _build_context_for_function( + CALCULATOR_SOURCE, + "calc.go", + "AddFloat", + tmp_path, + parents=[FunctionParent(name="Calculator", type="StructDef")], + is_method=True, + ) + + def test_full_assembled_code_string(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + + expected = dedent("""\ + package calc + + import ( + \t"fmt" + \t"math" + \tstr "strings" + ) + + // AddFloat adds a float value and records history. + func (c *Calculator) AddFloat(val float64) float64 { + \tc.Result += val + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + """) + assert code == expected + + def test_code_excludes_type_defs(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "type Calculator struct" not in code + assert "type Formatter interface" not in code + + def test_read_only_context_is_struct_definition(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + assert result.read_only_context_code == dedent("""\ + type Calculator struct { + \tResult float64 + \tHistory []float64 + }""") + + def test_no_helpers_when_no_calls(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + assert result.helper_functions == [] + + def test_target_not_duplicated_in_code_string(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + assert code.count("func (c *Calculator) AddFloat") == 1 + + +# --------------------------------------------------------------------------- +# Tests: targeting a value receiver method +# --------------------------------------------------------------------------- + + +class TestBuildContextValueReceiverMethod: + """Target: (c Calculator) Reset() — value receiver method.""" + + def _build(self, tmp_path: Path) -> CodeOptimizationContext: + return _build_context_for_function( + CALCULATOR_SOURCE, + "calc.go", + "Reset", + tmp_path, + parents=[FunctionParent(name="Calculator", type="StructDef")], + is_method=True, + ) + + def test_target_in_code_string(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + + expected_target = dedent("""\ + // Reset zeroes out the calculator. + func (c Calculator) Reset() Calculator { + \tc.Result = 0 + \tc.History = nil + \treturn c + }""") + assert code.count("func (c Calculator) Reset()") == 1 + assert expected_target in code + + def test_no_helpers_when_no_calls(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + assert result.helper_functions == [] + + def test_no_helper_code_in_assembled_string(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "func (c *Calculator) AddFloat" not in code + assert "func Add(a, b int) int" not in code + + def test_struct_in_read_only_context(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + assert result.read_only_context_code == dedent("""\ + type Calculator struct { + \tResult float64 + \tHistory []float64 + }""") + + +# --------------------------------------------------------------------------- +# Tests: simple source with no imports, no methods, one function +# --------------------------------------------------------------------------- + + +class TestBuildContextMinimalSource: + """Target: Double(x int) — minimal file with no imports or structs.""" + + def test_no_imports_package_only_prefix(self, tmp_path: Path) -> None: + result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert code == dedent("""\ + package simple + + func Double(x int) int { + \treturn x * 2 + }""") + + def test_no_helpers(self, tmp_path: Path) -> None: + result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) + assert result.helper_functions == [] + assert result.testgen_helper_fqns == [] + + def test_empty_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) + assert result.read_only_context_code == "" + + def test_preexisting_objects_empty(self, tmp_path: Path) -> None: + result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) + assert result.preexisting_objects == set() + + +# --------------------------------------------------------------------------- +# Tests: init function and globals in context +# --------------------------------------------------------------------------- + + +class TestBuildContextWithInit: + """Target: Process() — source has init(), global vars, consts, struct.""" + + def test_init_in_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + assert "func init()" in result.read_only_context_code + + def test_referenced_globals_in_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + assert "globalCache" in result.read_only_context_code + assert "mu" in result.read_only_context_code + + def test_referenced_const_in_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + assert "MaxRetries" in result.read_only_context_code + + def test_referenced_struct_in_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + assert "type Config struct" in result.read_only_context_code + + def test_init_not_in_helpers(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + helper_names = [h.only_function_name for h in result.helper_functions] + assert "init" not in helper_names + + def test_init_not_in_read_writable_code(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "func init()" not in code + + def test_full_read_only_context_string(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + expected = dedent("""\ + var ( + \tglobalCache map[string]int + \tmu sync.Mutex + ) + + const MaxRetries = 5 + + type Config struct { + \tName string + \tMax int + } + + func init() { + \tglobalCache = make(map[string]int) + \tglobalCache["default"] = 0 + \tdefaultCfg := Config{Name: "prod", Max: MaxRetries} + \t_ = defaultCfg + \tmu.Lock() + \tmu.Unlock() + }""") + assert result.read_only_context_code == expected + + +class TestBuildContextNoInit: + """Source without init — verify no init context is added.""" + + def test_no_init_no_extra_read_only(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert "func init()" not in result.read_only_context_code + + def test_no_init_read_only_empty_for_function(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.read_only_context_code == "" + + +# --------------------------------------------------------------------------- +# Tests: subdirectory / relative path handling +# --------------------------------------------------------------------------- + + +class TestBuildContextSubdirectory: + """Source file in a pkg/ subdirectory.""" + + def test_relative_path_includes_subdir(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + pkg = root / "pkg" + pkg.mkdir() + source_file = (pkg / "calc.go").resolve() + source_file.write_text(SIMPLE_SOURCE, encoding="utf-8") + + func = FunctionToOptimize(function_name="Double", file_path=source_file, language="go") + ctx = extract_code_context(func, root) + result = _build_optimization_context(ctx, source_file, "go", root) + + assert result.read_writable_code.code_strings[0].file_path == Path("pkg/calc.go") + + +# --------------------------------------------------------------------------- +# Tests: hashing +# --------------------------------------------------------------------------- + + +class TestBuildContextHashing: + def test_hash_is_sha256_of_flat(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + expected_hash = hashlib.sha256(result.read_writable_code.flat.encode("utf-8")).hexdigest() + assert result.hashing_code_context_hash == expected_hash + + def test_hashing_code_equals_flat(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.hashing_code_context == result.read_writable_code.flat + + def test_different_targets_different_hashes(self, tmp_path: Path) -> None: + dir_a = tmp_path / "a" + dir_a.mkdir() + dir_b = tmp_path / "b" + dir_b.mkdir() + + r1 = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", dir_a) + r2 = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Greet", dir_b) + + assert r1.hashing_code_context_hash != r2.hashing_code_context_hash + + +# --------------------------------------------------------------------------- +# Tests: testgen context +# --------------------------------------------------------------------------- + + +class TestBuildContextTestgen: + def test_testgen_matches_read_writable(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.testgen_context.markdown == result.read_writable_code.markdown + + +# --------------------------------------------------------------------------- +# Tests: token limit enforcement +# --------------------------------------------------------------------------- + + +class TestBuildContextTokenLimits: + def test_exceeds_optim_token_limit(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + source_file = (root / "big.go").resolve() + huge_code = "package big\n\nfunc Big() string {\n\treturn " + '"x" + ' * 100000 + '"x"\n}\n' + source_file.write_text(huge_code, encoding="utf-8") + + func = FunctionToOptimize(function_name="Big", file_path=source_file, language="go") + ctx = extract_code_context(func, root) + + with pytest.raises(ValueError, match="Read-writable code has exceeded token limit"): + _build_optimization_context(ctx, source_file, "go", root, optim_token_limit=10) + + def test_exceeds_testgen_token_limit(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + source_file = (root / "big.go").resolve() + huge_code = "package big\n\nfunc Big() string {\n\treturn " + '"x" + ' * 100000 + '"x"\n}\n' + source_file.write_text(huge_code, encoding="utf-8") + + func = FunctionToOptimize(function_name="Big", file_path=source_file, language="go") + ctx = extract_code_context(func, root) + + with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"): + _build_optimization_context( + ctx, source_file, "go", root, optim_token_limit=1_000_000, testgen_token_limit=10 + ) + + +# --------------------------------------------------------------------------- +# Tests: GoSupport wiring +# --------------------------------------------------------------------------- + + +class TestGoSupportFunctionOptimizerClass: + def test_returns_go_function_optimizer(self) -> None: + from codeflash.languages.golang.function_optimizer import GoFunctionOptimizer + from codeflash.languages.golang.support import GoSupport + + support = GoSupport() + assert support.function_optimizer_class is GoFunctionOptimizer diff --git a/tests/test_languages/test_golang/test_instrumentation.py b/tests/test_languages/test_golang/test_instrumentation.py new file mode 100644 index 000000000..182733ab8 --- /dev/null +++ b/tests/test_languages/test_golang/test_instrumentation.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +from codeflash.languages.golang.instrumentation import _test_matches_target, convert_tests_to_benchmarks + +SIMPLE_TEST = """\ +package sample + +import "testing" + +func TestAdd(t *testing.T) { +\tgot := Add(1, 2) +\tif got != 3 { +\t\tt.Errorf("Add(1, 2) = %d, want 3", got) +\t} +} +""" + +TEST_WITH_SUBTESTS = """\ +package sample + +import "testing" + +func TestBubbleSort_BasicCases(t *testing.T) { +\ttests := []struct { +\t\tname string +\t\tinput []int +\t\twant []int +\t}{ +\t\t{"sorted", []int{1, 2, 3}, []int{1, 2, 3}}, +\t} +\tfor _, tt := range tests { +\t\tt.Run(tt.name, func(t *testing.T) { +\t\t\tgot := BubbleSort(tt.input) +\t\t\tif len(got) != len(tt.want) { +\t\t\t\tt.Errorf("wrong length") +\t\t\t} +\t\t}) +\t} +} +""" + +MULTIPLE_TESTS = """\ +package sample + +import "testing" + +func TestFoo(t *testing.T) { +\tFoo() +} + +func TestBar(t *testing.T) { +\tBar() +} +""" + +BENCHMARK_ONLY = """\ +package sample + +import "testing" + +func BenchmarkFoo(b *testing.B) { +\tfor i := 0; i < b.N; i++ { +\t\tFoo() +\t} +} +""" + +TEST_WITH_HELPER = """\ +package sample + +import "testing" + +func equalSlices(t *testing.T, got, want []int) { +\tif len(got) != len(want) { +\t\tt.Fatalf("length mismatch") +\t} +} + +func TestBFS(t *testing.T) { +\tgot := BFS(graph, 0) +\tequalSlices(t, got, []int{0, 1, 2}) +} +""" + +TEST_WITH_PARALLEL = """\ +package sample + +import "testing" + +func TestFoo(t *testing.T) { +\tt.Parallel() +\tFoo() +} + +func TestBar(t *testing.T) { +\tt.Helper() +\tt.Parallel() +\tBar() +} +""" + + +class TestMatchesTarget: + def test_exact_match(self) -> None: + assert _test_matches_target("TestBFS", "BFS") is True + + def test_prefix_segment_match(self) -> None: + assert _test_matches_target("TestBFS_BasicCases", "BFS") is True + + def test_suffix_segment_match(self) -> None: + assert _test_matches_target("TestGraph_BFS", "BFS") is True + + def test_no_match_substring(self) -> None: + assert _test_matches_target("TestBFSHelper", "BFS") is False + + def test_no_match_different_function(self) -> None: + assert _test_matches_target("TestDFS", "BFS") is False + + def test_multi_underscore(self) -> None: + assert _test_matches_target("TestBFS_Large_Graph", "BFS") is True + + +class TestConvertTestsToBenchmarks: + def test_simple_test(self) -> None: + result = convert_tests_to_benchmarks(SIMPLE_TEST, "Add") + assert "func BenchmarkAdd(" in result + assert "*testing.B)" in result + assert "for i := 0; i < " in result + assert ".N; i++ {" in result + assert "func TestAdd(" not in result + + def test_subtests_converted(self) -> None: + result = convert_tests_to_benchmarks(TEST_WITH_SUBTESTS, "BubbleSort") + assert "func BenchmarkBubbleSort_BasicCases(" in result + assert "*testing.T" not in result + + def test_multiple_functions_filtered(self) -> None: + result = convert_tests_to_benchmarks(MULTIPLE_TESTS, "Foo") + assert "func BenchmarkFoo(" in result + assert "func BenchmarkBar(" not in result + assert "func TestFoo(" not in result + assert "func TestBar(" not in result + + def test_multiple_functions_no_filter(self) -> None: + result = convert_tests_to_benchmarks(MULTIPLE_TESTS, "") + assert "func BenchmarkFoo(" in result + assert "func BenchmarkBar(" in result + assert "func TestFoo(" not in result + assert "func TestBar(" not in result + + def test_empty_source(self) -> None: + assert convert_tests_to_benchmarks("", "Foo") == "" + + def test_no_test_functions(self) -> None: + result = convert_tests_to_benchmarks(BENCHMARK_ONLY, "Foo") + assert result == BENCHMARK_ONLY + + def test_package_preserved(self) -> None: + result = convert_tests_to_benchmarks(SIMPLE_TEST, "Add") + assert result.startswith("package sample") + + def test_import_preserved(self) -> None: + result = convert_tests_to_benchmarks(SIMPLE_TEST, "Add") + assert 'import "testing"' in result + + def test_helper_functions_converted(self) -> None: + result = convert_tests_to_benchmarks(TEST_WITH_HELPER, "BFS") + assert "func BenchmarkBFS(" in result + assert "*testing.T" not in result + assert "equalSlices" in result + assert "*testing.B" in result + + def test_parallel_removed(self) -> None: + result = convert_tests_to_benchmarks(TEST_WITH_PARALLEL, "Foo") + assert ".Parallel()" not in result + assert ".Helper()" not in result + assert "func BenchmarkFoo(" in result + assert "func BenchmarkBar(" not in result diff --git a/tests/test_languages/test_golang/test_parse.py b/tests/test_languages/test_golang/test_parse.py new file mode 100644 index 000000000..92dd5c97e --- /dev/null +++ b/tests/test_languages/test_golang/test_parse.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import json +import subprocess +from pathlib import Path +from unittest.mock import MagicMock + +from codeflash.languages.golang.parse import BENCHMARK_RE, parse_go_test_output +from codeflash.models.models import TestFile, TestFiles +from codeflash.models.test_type import TestType + + +def _make_test_config(tmp_path: Path) -> MagicMock: + cfg = MagicMock() + cfg.tests_project_rootdir = tmp_path + cfg.test_framework = "go-test" + return cfg + + +def _make_test_files(tmp_path: Path, filenames: list[str] | None = None, test_type: TestType = TestType.GENERATED_REGRESSION) -> TestFiles: + if filenames is None: + filenames = ["calc_test.go"] + files = [] + for name in filenames: + path = (tmp_path / name).resolve() + path.write_text("package calc\n", encoding="utf-8") + files.append( + TestFile( + instrumented_behavior_file_path=path, + test_type=test_type, + ) + ) + return TestFiles(test_files=files) + + +def _write_jsonl(path: Path, events: list[dict]) -> None: + path.write_text("\n".join(json.dumps(e) for e in events) + "\n", encoding="utf-8") + + +class TestBenchmarkRegex: + def test_basic_benchmark_line(self) -> None: + line = "BenchmarkAdd-8 \t 1000000\t 1234 ns/op\t 56 B/op\t 2 allocs/op" + m = BENCHMARK_RE.search(line) + assert m is not None + assert m.group(1) == "BenchmarkAdd" + assert m.group(2) == "1000000" + assert m.group(3) == "1234" + assert m.group(4) == "56" + assert m.group(5) == "2" + + def test_benchmark_without_mem(self) -> None: + line = "BenchmarkSort 5000 300000 ns/op" + m = BENCHMARK_RE.search(line) + assert m is not None + assert m.group(1) == "BenchmarkSort" + assert m.group(4) is None + assert m.group(5) is None + + def test_benchmark_with_float_ns(self) -> None: + line = "BenchmarkFib-16 100000 12345.67 ns/op" + m = BENCHMARK_RE.search(line) + assert m is not None + assert m.group(3) == "12345.67" + + def test_non_benchmark_line(self) -> None: + line = "=== RUN TestAdd" + m = BENCHMARK_RE.search(line) + assert m is None + + +class TestParseGoTestOutputBehavioral: + def test_all_passing(self, tmp_path: Path) -> None: + events = [ + {"Time": "2024-01-01T00:00:00Z", "Action": "run", "Package": "example/calc", "Test": "TestAdd"}, + {"Time": "2024-01-01T00:00:00Z", "Action": "output", "Package": "example/calc", "Test": "TestAdd", "Output": "=== RUN TestAdd\n"}, + {"Time": "2024-01-01T00:00:00Z", "Action": "output", "Package": "example/calc", "Test": "TestAdd", "Output": "--- PASS: TestAdd (0.00s)\n"}, + {"Time": "2024-01-01T00:00:00Z", "Action": "pass", "Package": "example/calc", "Test": "TestAdd", "Elapsed": 0.001}, + {"Time": "2024-01-01T00:00:00Z", "Action": "run", "Package": "example/calc", "Test": "TestSub"}, + {"Time": "2024-01-01T00:00:00Z", "Action": "pass", "Package": "example/calc", "Test": "TestSub", "Elapsed": 0.002}, + {"Time": "2024-01-01T00:00:00Z", "Action": "pass", "Package": "example/calc"}, + ] + + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 2 + + by_name = {r.id.test_function_name: r for r in results.test_results} + assert by_name["TestAdd"].did_pass is True + assert by_name["TestAdd"].runtime == 1_000_000 + assert by_name["TestSub"].did_pass is True + assert by_name["TestSub"].runtime == 2_000_000 + + def test_with_failure(self, tmp_path: Path) -> None: + events = [ + {"Action": "run", "Package": "example/calc", "Test": "TestAdd"}, + {"Action": "output", "Package": "example/calc", "Test": "TestAdd", "Output": "got 4, want 5\n"}, + {"Action": "fail", "Package": "example/calc", "Test": "TestAdd", "Elapsed": 0.01}, + ] + + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 1 + assert results.test_results[0].did_pass is False + assert "got 4, want 5" in results.test_results[0].stdout + + def test_mixed_pass_fail(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestA", "Elapsed": 0.001}, + {"Action": "fail", "Package": "p", "Test": "TestB", "Elapsed": 0.002}, + {"Action": "pass", "Package": "p", "Test": "TestC", "Elapsed": 0.003}, + ] + + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + by_name = {r.id.test_function_name: r for r in results.test_results} + assert by_name["TestA"].did_pass is True + assert by_name["TestB"].did_pass is False + assert by_name["TestC"].did_pass is True + + def test_empty_file(self, tmp_path: Path) -> None: + json_path = (tmp_path / "empty.jsonl").resolve() + json_path.write_text("", encoding="utf-8") + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 0 + + def test_missing_file_falls_back_to_run_result(self, tmp_path: Path) -> None: + json_path = (tmp_path / "nonexistent.jsonl").resolve() + events = [ + {"Action": "pass", "Package": "p", "Test": "TestX", "Elapsed": 0.005}, + ] + stdout = "\n".join(json.dumps(e) for e in events) + run_result = subprocess.CompletedProcess(args=[], returncode=0, stdout=stdout, stderr="") + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg, run_result) + assert len(results.test_results) == 1 + assert results.test_results[0].id.test_function_name == "TestX" + + def test_invalid_json_lines_skipped(self, tmp_path: Path) -> None: + content = 'not json\n{"Action":"pass","Package":"p","Test":"TestOK","Elapsed":0.001}\nalso bad\n' + json_path = (tmp_path / "results.jsonl").resolve() + json_path.write_text(content, encoding="utf-8") + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 1 + assert results.test_results[0].did_pass is True + + def test_test_type_from_test_files(self, tmp_path: Path) -> None: + test_files = _make_test_files(tmp_path, test_type=TestType.EXISTING_UNIT_TEST) + events = [ + {"Action": "pass", "Package": "p", "Test": "TestFoo", "Elapsed": 0.001}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert results.test_results[0].test_type == TestType.EXISTING_UNIT_TEST + + def test_framework_is_go_test(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestBar", "Elapsed": 0.001}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert results.test_results[0].test_framework == "go-test" + + def test_package_level_events_ignored(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestOK", "Elapsed": 0.001}, + {"Action": "pass", "Package": "p", "Elapsed": 0.5}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 1 + + +class TestParseGoTestOutputBenchmark: + def test_benchmark_parsing(self, tmp_path: Path) -> None: + events = [ + {"Action": "run", "Package": "p", "Test": "BenchmarkAdd"}, + {"Action": "output", "Package": "p", "Test": "BenchmarkAdd", "Output": "BenchmarkAdd-8 \t 1000000\t 1234 ns/op\t 56 B/op\t 2 allocs/op\n"}, + {"Action": "pass", "Package": "p", "Test": "BenchmarkAdd", "Elapsed": 1.5}, + ] + json_path = (tmp_path / "bench.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 1 + result = results.test_results[0] + assert result.did_pass is True + assert result.runtime == 1234 + + def test_benchmark_overrides_elapsed(self, tmp_path: Path) -> None: + events = [ + {"Action": "output", "Package": "p", "Test": "BenchmarkSort", "Output": "BenchmarkSort 5000 300000 ns/op\n"}, + {"Action": "pass", "Package": "p", "Test": "BenchmarkSort", "Elapsed": 2.0}, + ] + json_path = (tmp_path / "bench.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert results.test_results[0].runtime == 300000 + + def test_mixed_tests_and_benchmarks(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestAdd", "Elapsed": 0.001}, + {"Action": "output", "Package": "p", "Test": "BenchmarkAdd", "Output": "BenchmarkAdd-8 1000000 500 ns/op\n"}, + {"Action": "pass", "Package": "p", "Test": "BenchmarkAdd", "Elapsed": 1.0}, + ] + json_path = (tmp_path / "mixed.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + by_name = {r.id.test_function_name: r for r in results.test_results} + + assert by_name["TestAdd"].runtime == 1_000_000 + assert by_name["BenchmarkAdd"].runtime == 500 + + +class TestParseGoTestOutputInvocationId: + def test_invocation_id_fields(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "example/calc", "Test": "TestAdd", "Elapsed": 0.001}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + inv = results.test_results[0] + assert inv.id.test_module_path == "example/calc" + assert inv.id.test_class_name is None + assert inv.id.test_function_name == "TestAdd" + assert inv.loop_index == 1 + + def test_unique_invocation_loop_id_stable(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestA", "Elapsed": 0.001}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results1 = parse_go_test_output(json_path, test_files, cfg) + results2 = parse_go_test_output(json_path, test_files, cfg) + + id1 = results1.test_results[0].unique_invocation_loop_id + id2 = results2.test_results[0].unique_invocation_loop_id + assert id1 == id2 + + +class TestParseGoTestOutputComparison: + def test_behavioral_comparison_same_results(self, tmp_path: Path) -> None: + from codeflash.verification.equivalence import compare_test_results + + events = [ + {"Action": "pass", "Package": "p", "Test": "TestAdd", "Elapsed": 0.001}, + {"Action": "pass", "Package": "p", "Test": "TestSub", "Elapsed": 0.002}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + original = parse_go_test_output(json_path, test_files, cfg) + candidate = parse_go_test_output(json_path, test_files, cfg) + + are_equal, diffs = compare_test_results(original, candidate, pass_fail_only=True) + assert are_equal is True + assert diffs == [] + + def test_behavioral_comparison_different_results(self, tmp_path: Path) -> None: + from codeflash.verification.equivalence import compare_test_results + + original_events = [ + {"Action": "pass", "Package": "p", "Test": "TestAdd", "Elapsed": 0.001}, + ] + candidate_events = [ + {"Action": "fail", "Package": "p", "Test": "TestAdd", "Elapsed": 0.001}, + ] + orig_path = (tmp_path / "orig.jsonl").resolve() + cand_path = (tmp_path / "cand.jsonl").resolve() + _write_jsonl(orig_path, original_events) + _write_jsonl(cand_path, candidate_events) + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + original = parse_go_test_output(orig_path, test_files, cfg) + candidate = parse_go_test_output(cand_path, test_files, cfg) + + are_equal, diffs = compare_test_results(original, candidate, pass_fail_only=True) + assert are_equal is False + assert len(diffs) == 1 + + def test_empty_results_not_equal(self, tmp_path: Path) -> None: + from codeflash.models.models import TestResults + from codeflash.verification.equivalence import compare_test_results + + are_equal, _diffs = compare_test_results(TestResults(), TestResults(), pass_fail_only=True) + assert are_equal is False diff --git a/tests/test_languages/test_golang/test_parser.py b/tests/test_languages/test_golang/test_parser.py new file mode 100644 index 000000000..2179e92db --- /dev/null +++ b/tests/test_languages/test_golang/test_parser.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +from codeflash.languages.golang.parser import GoAnalyzer + +GO_SOURCE = """\ +package calculator + +import "math" + +// Add returns the sum of two integers. +func Add(a, b int) int { + return a + b +} + +func subtract(a, b int) int { + return a - b +} + +func noReturn() { + println("hello") +} + +type Calculator struct { + Result float64 +} + +// AddFloat adds a value. +func (c *Calculator) AddFloat(val float64) float64 { + c.Result += val + return c.Result +} + +func (c Calculator) GetResult() float64 { + return c.Result +} + +// Reset zeroes the calculator. +func (c *Calculator) Reset() { + c.Result = 0 +} + +type Adder interface { + Add(a, b int) int +} +""" + + +class TestGoAnalyzerFunctions: + def test_find_functions(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + names = [f.name for f in functions] + assert "Add" in names + assert "subtract" in names + assert "noReturn" in names + + def test_exported_detection(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + by_name = {f.name: f for f in functions} + assert by_name["Add"].is_exported is True + assert by_name["subtract"].is_exported is False + + def test_return_type_detection(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + by_name = {f.name: f for f in functions} + assert by_name["Add"].has_return_type is True + assert by_name["noReturn"].has_return_type is False + + def test_doc_comment_detection(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + by_name = {f.name: f for f in functions} + assert by_name["Add"].doc_start_line is not None + assert by_name["subtract"].doc_start_line is None + + def test_line_numbers(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + by_name = {f.name: f for f in functions} + add_func = by_name["Add"] + assert add_func.starting_line == 6 + assert add_func.ending_line == 8 + + +class TestGoAnalyzerMethods: + def test_find_methods(self) -> None: + analyzer = GoAnalyzer() + methods = analyzer.find_methods(GO_SOURCE) + names = [m.name for m in methods] + assert "AddFloat" in names + assert "GetResult" in names + assert "Reset" in names + + def test_receiver_detection(self) -> None: + analyzer = GoAnalyzer() + methods = analyzer.find_methods(GO_SOURCE) + by_name = {m.name: m for m in methods} + assert by_name["AddFloat"].receiver_name == "Calculator" + assert by_name["AddFloat"].receiver_is_pointer is True + assert by_name["GetResult"].receiver_name == "Calculator" + assert by_name["GetResult"].receiver_is_pointer is False + + def test_method_doc_comment(self) -> None: + analyzer = GoAnalyzer() + methods = analyzer.find_methods(GO_SOURCE) + by_name = {m.name: m for m in methods} + assert by_name["AddFloat"].doc_start_line is not None + assert by_name["Reset"].doc_start_line is not None + assert by_name["GetResult"].doc_start_line is None + + def test_method_exported(self) -> None: + analyzer = GoAnalyzer() + methods = analyzer.find_methods(GO_SOURCE) + for m in methods: + assert m.is_exported is True + + +class TestGoAnalyzerStructs: + def test_find_structs(self) -> None: + analyzer = GoAnalyzer() + structs = analyzer.find_structs(GO_SOURCE) + assert len(structs) == 1 + assert structs[0].name == "Calculator" + assert len(structs[0].fields) > 0 + + def test_struct_field_content(self) -> None: + analyzer = GoAnalyzer() + structs = analyzer.find_structs(GO_SOURCE) + field_text = " ".join(structs[0].fields) + assert "Result" in field_text + assert "float64" in field_text + + +class TestGoAnalyzerInterfaces: + def test_find_interfaces(self) -> None: + analyzer = GoAnalyzer() + interfaces = analyzer.find_interfaces(GO_SOURCE) + assert len(interfaces) == 1 + assert interfaces[0].name == "Adder" + assert len(interfaces[0].methods) > 0 + + +class TestGoAnalyzerImports: + def test_find_imports(self) -> None: + analyzer = GoAnalyzer() + imports = analyzer.find_imports(GO_SOURCE) + assert len(imports) == 1 + assert imports[0].path == "math" + assert imports[0].alias is None + + def test_multi_import(self) -> None: + source = '''\ +package main + +import ( + "fmt" + "os" + str "strings" +) + +func Main() string { + return "hello" +} +''' + analyzer = GoAnalyzer() + imports = analyzer.find_imports(source) + paths = {i.path for i in imports} + assert paths == {"fmt", "os", "strings"} + aliases = {i.path: i.alias for i in imports} + assert aliases["strings"] == "str" + assert aliases["fmt"] is None + + +class TestGoAnalyzerPackage: + def test_find_package_name(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.find_package_name(GO_SOURCE) == "calculator" + + def test_find_package_name_main(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.find_package_name("package main\n\nfunc main() {}") == "main" + + +class TestGoAnalyzerSyntax: + def test_valid_syntax(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.validate_syntax(GO_SOURCE) is True + + def test_invalid_syntax(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.validate_syntax("func {{{invalid") is False + + +class TestGoAnalyzerExtract: + def test_extract_function_source(self) -> None: + analyzer = GoAnalyzer() + source = analyzer.extract_function_source(GO_SOURCE, "Add") + assert source is not None + assert "func Add" in source + assert "return a + b" in source + + def test_extract_function_source_with_doc(self) -> None: + analyzer = GoAnalyzer() + source = analyzer.extract_function_source(GO_SOURCE, "Add") + assert source is not None + assert "// Add returns" in source + + def test_extract_method_source(self) -> None: + analyzer = GoAnalyzer() + source = analyzer.extract_function_source(GO_SOURCE, "AddFloat", receiver_type="Calculator") + assert source is not None + assert "func (c *Calculator) AddFloat" in source + + def test_extract_nonexistent(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.extract_function_source(GO_SOURCE, "DoesNotExist") is None + + +GLOBALS_SOURCE = """\ +package server + +import "sync" + +var ( +\tglobalCache map[string]int +\tmu sync.Mutex +) + +var singleVar = 42 + +const MaxRetries = 5 + +const ( +\tDefaultName = "prod" +\tTimeout = 30 +) + +type Config struct { +\tName string +\tMax int +} + +func init() { +\tglobalCache = make(map[string]int) +\tglobalCache["default"] = 0 +\tdefaultCfg := Config{Name: DefaultName, Max: MaxRetries} +\t_ = defaultCfg +\tmu.Lock() +\tmu.Unlock() +} + +func Process() int { +\treturn singleVar + MaxRetries +} +""" + + +class TestGoAnalyzerGlobalDeclarations: + def test_find_var_group(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + var_decls = [d for d in decls if d.kind == "var"] + all_names = [name for d in var_decls for name in d.names] + assert "globalCache" in all_names + assert "mu" in all_names + assert "singleVar" in all_names + + def test_find_const_group(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + const_decls = [d for d in decls if d.kind == "const"] + all_names = [name for d in const_decls for name in d.names] + assert "MaxRetries" in all_names + assert "DefaultName" in all_names + assert "Timeout" in all_names + + def test_grouped_var_names_together(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + var_group = next(d for d in decls if "globalCache" in d.names) + assert var_group.names == ("globalCache", "mu") + + def test_single_var(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + single = next(d for d in decls if "singleVar" in d.names) + assert single.kind == "var" + assert single.source_code == "var singleVar = 42" + + def test_const_group_source_code(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + group = next(d for d in decls if "DefaultName" in d.names) + assert "DefaultName" in group.source_code + assert "Timeout" in group.source_code + + def test_no_globals_in_clean_source(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations("package main\n\nfunc main() {}\n") + assert decls == [] + + +class TestGoAnalyzerCollectBodyIdentifiers: + def test_init_body_identifiers(self) -> None: + analyzer = GoAnalyzer() + ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "init") + assert "globalCache" in ids + assert "Config" in ids + assert "DefaultName" in ids + assert "MaxRetries" in ids + assert "mu" in ids + + def test_process_body_identifiers(self) -> None: + analyzer = GoAnalyzer() + ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "Process") + assert "singleVar" in ids + assert "MaxRetries" in ids + + def test_nonexistent_function_returns_empty(self) -> None: + analyzer = GoAnalyzer() + ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "DoesNotExist") + assert ids == set() + + def test_method_body_identifiers(self) -> None: + source = """\ +package calc + +type Calc struct{ val int } + +var offset = 10 + +func (c *Calc) Compute() int { +\treturn c.val + offset +} +""" + analyzer = GoAnalyzer() + ids = analyzer.collect_body_identifiers(source, "Compute", receiver_type="Calc") + assert "offset" in ids diff --git a/tests/test_languages/test_golang/test_replacement.py b/tests/test_languages/test_golang/test_replacement.py new file mode 100644 index 000000000..5cd444aab --- /dev/null +++ b/tests/test_languages/test_golang/test_replacement.py @@ -0,0 +1,658 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.replacement import add_global_declarations, remove_test_functions, replace_function +from codeflash.models.function_types import FunctionParent, FunctionToOptimize + + +class TestReplaceFunction: + def test_replace_basic_function(self) -> None: + source = "package calc\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n\nfunc Subtract(a, b int) int {\n\treturn a - b\n}\n" + func = FunctionToOptimize(function_name="Add", file_path=Path("/project/calc.go"), language="go") + new_body = "func Add(a, b int) int {\n\tresult := a + b\n\treturn result\n}" + result = replace_function(source, func, new_body) + expected = "package calc\n\nfunc Add(a, b int) int {\n\tresult := a + b\n\treturn result\n}\n\nfunc Subtract(a, b int) int {\n\treturn a - b\n}\n" + assert result == expected + + def test_replace_function_with_doc_comment(self) -> None: + source = "package calc\n\n// Add returns the sum.\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + func = FunctionToOptimize(function_name="Add", file_path=Path("/project/calc.go"), language="go") + new_body = "// Add returns an optimized sum.\nfunc Add(a, b int) int {\n\treturn a + b\n}" + result = replace_function(source, func, new_body) + expected = "package calc\n\n// Add returns an optimized sum.\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + assert result == expected + + def test_replace_method(self) -> None: + source = ( + "package calc\n\n" + "type Calc struct {\n\tResult float64\n}\n\n" + "// AddFloat adds a value.\n" + "func (c *Calc) AddFloat(val float64) float64 {\n\tc.Result += val\n\treturn c.Result\n}\n\n" + "func (c Calc) GetResult() float64 {\n\treturn c.Result\n}\n" + ) + func = FunctionToOptimize( + function_name="AddFloat", + file_path=Path("/project/calc.go"), + parents=[FunctionParent(name="Calc", type="StructDef")], + language="go", + is_method=True, + ) + new_body = "// AddFloat adds a value (optimized).\nfunc (c *Calc) AddFloat(val float64) float64 {\n\tc.Result = c.Result + val\n\treturn c.Result\n}" + result = replace_function(source, func, new_body) + expected = ( + "package calc\n\n" + "type Calc struct {\n\tResult float64\n}\n\n" + "// AddFloat adds a value (optimized).\n" + "func (c *Calc) AddFloat(val float64) float64 {\n\tc.Result = c.Result + val\n\treturn c.Result\n}\n\n" + "func (c Calc) GetResult() float64 {\n\treturn c.Result\n}\n" + ) + assert result == expected + + def test_replace_nonexistent_returns_original(self) -> None: + source = "package calc\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + func = FunctionToOptimize(function_name="Missing", file_path=Path("/project/calc.go"), language="go") + result = replace_function(source, func, "func Missing() {}") + assert result == source + + def test_replace_preserves_surrounding_code(self) -> None: + source = ( + "package calc\n\n" + "var version = 1\n\n" + "func Add(a, b int) int {\n\treturn a + b\n}\n\n" + "func Subtract(a, b int) int {\n\treturn a - b\n}\n" + ) + func = FunctionToOptimize(function_name="Add", file_path=Path("/project/calc.go"), language="go") + new_body = "func Add(a, b int) int {\n\treturn b + a\n}" + result = replace_function(source, func, new_body) + expected = ( + "package calc\n\n" + "var version = 1\n\n" + "func Add(a, b int) int {\n\treturn b + a\n}\n\n" + "func Subtract(a, b int) int {\n\treturn a - b\n}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsImports: + def test_add_import_to_existing_block(self) -> None: + original = 'package calc\n\nimport (\n\t"fmt"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + optimized = 'package calc\n\nimport (\n\t"fmt"\n\t"math"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + result = add_global_declarations(optimized, original) + expected = 'package calc\n\nimport (\n\t"fmt"\n\t"math"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + assert result == expected + + def test_add_aliased_import(self) -> None: + original = 'package calc\n\nimport (\n\t"fmt"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + optimized = 'package calc\n\nimport (\n\t"fmt"\n\tstr "strings"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + result = add_global_declarations(optimized, original) + expected = 'package calc\n\nimport (\n\t"fmt"\n\tstr "strings"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + assert result == expected + + def test_add_import_when_no_existing_imports(self) -> None: + original = "package calc\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + optimized = 'package calc\n\nimport "math"\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + result = add_global_declarations(optimized, original) + expected = 'package calc\nimport (\n\t"math"\n)\n\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + assert result == expected + + def test_no_new_imports_returns_unchanged(self) -> None: + source = "package calc\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + result = add_global_declarations(source, source) + assert result == source + + +class TestAddGlobalDeclarationsNewVar: + def test_add_single_new_var(self) -> None: + original = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "var cache = make(map[int]int)\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n" + "var cache = make(map[int]int)\n\n" + "\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + def test_add_grouped_var_block(self) -> None: + original = ( + "package server\n\n" + 'import "fmt"\n\n' + "func Process() {\n" + "\tfmt.Println()\n" + "}\n" + ) + optimized = ( + "package server\n\n" + 'import "fmt"\n\n' + "var (\n" + "\tcache map[string]int\n" + "\tbuffer []byte\n" + ")\n\n" + "func Process() {\n" + "\tfmt.Println()\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + 'import "fmt"\n' + "var (\n" + "\tcache map[string]int\n" + "\tbuffer []byte\n" + ")\n\n" + "\n" + "func Process() {\n" + "\tfmt.Println()\n" + "}\n" + ) + assert result == expected + + def test_add_new_var_preserves_existing_var(self) -> None: + original = ( + "package calc\n\n" + "var version = 1\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "var version = 1\n\n" + "var cache = make(map[int]int)\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "var version = 1\n" + "var cache = make(map[int]int)\n\n" + "\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsNewConst: + def test_add_single_new_const(self) -> None: + original = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "const maxSize = 1024\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n" + "const maxSize = 1024\n\n" + "\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + def test_add_grouped_const_block(self) -> None: + original = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "const (\n" + "\tMaxRetries = 5\n" + "\tTimeout = 30\n" + ")\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n" + "const (\n" + "\tMaxRetries = 5\n" + "\tTimeout = 30\n" + ")\n\n" + "\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + def test_add_new_const_preserves_existing_const(self) -> None: + original = ( + "package calc\n\n" + "const Pi = 3.14\n\n" + "func Area(r float64) float64 {\n" + "\treturn Pi * r * r\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "const Pi = 3.14\n\n" + "const TwoPi = 6.28\n\n" + "func Area(r float64) float64 {\n" + "\treturn Pi * r * r\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "const Pi = 3.14\n" + "const TwoPi = 6.28\n\n" + "\n" + "func Area(r float64) float64 {\n" + "\treturn Pi * r * r\n" + "}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsModifyVar: + def test_modify_single_var_value(self) -> None: + original = ( + "package calc\n\n" + "var bufferSize = 256\n\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "var bufferSize = 1024\n\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "var bufferSize = 1024\n" + "\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + assert result == expected + + def test_modify_grouped_var_block(self) -> None: + original = ( + "package server\n\n" + "var (\n" + '\thost = "localhost"\n' + "\tport = 8080\n" + ")\n\n" + "func Addr() string {\n" + "\treturn host\n" + "}\n" + ) + optimized = ( + "package server\n\n" + "var (\n" + '\thost = "0.0.0.0"\n' + "\tport = 9090\n" + ")\n\n" + "func Addr() string {\n" + "\treturn host\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + "var (\n" + '\thost = "0.0.0.0"\n' + "\tport = 9090\n" + ")\n" + "\n" + "func Addr() string {\n" + "\treturn host\n" + "}\n" + ) + assert result == expected + + def test_modify_var_type(self) -> None: + original = ( + "package calc\n\n" + "var counter int\n\n" + "func Inc() {\n" + "\tcounter++\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "var counter int64\n\n" + "func Inc() {\n" + "\tcounter++\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "var counter int64\n" + "\n" + "func Inc() {\n" + "\tcounter++\n" + "}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsModifyConst: + def test_modify_single_const_value(self) -> None: + original = ( + "package calc\n\n" + "const MaxRetries = 3\n\n" + "func Retries() int {\n" + "\treturn MaxRetries\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "const MaxRetries = 10\n\n" + "func Retries() int {\n" + "\treturn MaxRetries\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "const MaxRetries = 10\n" + "\n" + "func Retries() int {\n" + "\treturn MaxRetries\n" + "}\n" + ) + assert result == expected + + def test_modify_const_group(self) -> None: + original = ( + "package server\n\n" + "const (\n" + "\tDefaultTimeout = 30\n" + "\tMaxConnections = 100\n" + ")\n\n" + "func Config() int {\n" + "\treturn DefaultTimeout\n" + "}\n" + ) + optimized = ( + "package server\n\n" + "const (\n" + "\tDefaultTimeout = 60\n" + "\tMaxConnections = 500\n" + ")\n\n" + "func Config() int {\n" + "\treturn DefaultTimeout\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + "const (\n" + "\tDefaultTimeout = 60\n" + "\tMaxConnections = 500\n" + ")\n" + "\n" + "func Config() int {\n" + "\treturn DefaultTimeout\n" + "}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsMixed: + def test_new_import_and_new_var(self) -> None: + original = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + 'import "sync"\n\n' + "var mu sync.Mutex\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n" + "import (\n" + '\t"sync"\n' + ")\n" + "var mu sync.Mutex\n\n" + "\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + def test_new_and_modified_globals_together(self) -> None: + original = ( + "package server\n\n" + "var bufferSize = 256\n\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + optimized = ( + "package server\n\n" + "var bufferSize = 1024\n\n" + "var cache = make(map[string]int)\n\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + "var bufferSize = 1024\n" + "var cache = make(map[string]int)\n\n" + "\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + assert result == expected + + def test_no_globals_in_optimized_returns_unchanged(self) -> None: + original = ( + "package calc\n\n" + "var version = 1\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + assert result == original + + def test_identical_globals_returns_unchanged(self) -> None: + source = ( + "package calc\n\n" + "var version = 1\n\n" + "const MaxSize = 100\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(source, source) + assert result == source + + def test_full_round_trip_new_import_var_const(self) -> None: + original = ( + "package server\n\n" + "import (\n" + '\t"fmt"\n' + ")\n\n" + "const Version = 1\n\n" + "func Handle() {\n" + "\tfmt.Println()\n" + "}\n" + ) + optimized = ( + "package server\n\n" + "import (\n" + '\t"fmt"\n' + '\t"sync"\n' + ")\n\n" + "const Version = 1\n\n" + "var mu sync.Mutex\n\n" + "const MaxConns = 100\n\n" + "func Handle() {\n" + "\tmu.Lock()\n" + "\tdefer mu.Unlock()\n" + "\tfmt.Println()\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + "import (\n" + '\t"fmt"\n' + '\t"sync"\n' + ")\n\n" + "const Version = 1\n" + "var mu sync.Mutex\n" + "const MaxConns = 100\n\n" + "\n" + "func Handle() {\n" + "\tfmt.Println()\n" + "}\n" + ) + assert result == expected + + +class TestRemoveTestFunctions: + def test_remove_single_function(self) -> None: + test_source = ( + "package calc\n\n" + 'import "testing"\n\n' + "func TestAdd(t *testing.T) {\n" + "\tresult := Add(2, 3)\n" + "\tif result != 5 {\n" + '\t\tt.Errorf("want 5, got %d", result)\n' + "\t}\n" + "}\n\n" + "func TestSubtract(t *testing.T) {\n" + "\tresult := Subtract(5, 3)\n" + "\tif result != 2 {\n" + '\t\tt.Errorf("want 2, got %d", result)\n' + "\t}\n" + "}\n" + ) + result = remove_test_functions(test_source, ["TestAdd"]) + expected = ( + "package calc\n\n" + 'import "testing"\n\n\n' + "func TestSubtract(t *testing.T) {\n" + "\tresult := Subtract(5, 3)\n" + "\tif result != 2 {\n" + '\t\tt.Errorf("want 2, got %d", result)\n' + "\t}\n" + "}\n" + ) + assert result == expected + + def test_remove_multiple_functions(self) -> None: + test_source = ( + "package calc\n\n" + 'import "testing"\n\n' + "// TestAdd tests addition.\n" + "func TestAdd(t *testing.T) {\n" + "\tif Add(1, 2) != 3 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n\n" + "func TestSubtract(t *testing.T) {\n" + "\tif Subtract(5, 3) != 2 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n\n" + "func TestMultiply(t *testing.T) {\n" + "\tif Multiply(2, 3) != 6 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n" + ) + result = remove_test_functions(test_source, ["TestAdd", "TestMultiply"]) + expected = ( + "package calc\n\n" + 'import "testing"\n\n\n' + "func TestSubtract(t *testing.T) {\n" + "\tif Subtract(5, 3) != 2 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n\n" + ) + assert result == expected + + def test_remove_function_with_doc_comment(self) -> None: + test_source = ( + "package calc\n\n" + 'import "testing"\n\n' + "// TestAdd tests addition.\n" + "func TestAdd(t *testing.T) {\n" + "\tif Add(1, 2) != 3 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n\n" + "func TestSubtract(t *testing.T) {\n" + "\tif Subtract(5, 3) != 2 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n" + ) + result = remove_test_functions(test_source, ["TestAdd"]) + expected = ( + "package calc\n\n" + 'import "testing"\n\n\n' + "func TestSubtract(t *testing.T) {\n" + "\tif Subtract(5, 3) != 2 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n" + ) + assert result == expected + + def test_remove_none_returns_unchanged(self) -> None: + test_source = "package calc\n\nfunc TestAdd(t *testing.T) {\n\tt.Log(\"ok\")\n}\n" + result = remove_test_functions(test_source, []) + assert result == test_source diff --git a/tests/test_languages/test_golang/test_support.py b/tests/test_languages/test_golang/test_support.py new file mode 100644 index 000000000..5c415c78e --- /dev/null +++ b/tests/test_languages/test_golang/test_support.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.support import GoSupport +from codeflash.languages.language_enum import Language +from codeflash.languages.registry import get_language_support + + +class TestGoSupportProperties: + def test_language(self) -> None: + support = GoSupport() + assert support.language == Language.GO + + def test_file_extensions(self) -> None: + support = GoSupport() + assert support.file_extensions == (".go",) + + def test_default_file_extension(self) -> None: + support = GoSupport() + assert support.default_file_extension == ".go" + + def test_test_framework(self) -> None: + support = GoSupport() + assert support.test_framework == "go-test" + + def test_comment_prefix(self) -> None: + support = GoSupport() + assert support.comment_prefix == "//" + + def test_valid_test_frameworks(self) -> None: + support = GoSupport() + assert support.valid_test_frameworks == ("go-test",) + + def test_serialization_format(self) -> None: + support = GoSupport() + assert support.test_result_serialization_format == "json" + + def test_get_test_file_suffix(self) -> None: + support = GoSupport() + assert support.get_test_file_suffix() == "_test.go" + + def test_dir_excludes(self) -> None: + support = GoSupport() + assert "vendor" in support.dir_excludes + assert "testdata" in support.dir_excludes + + +class TestGoSupportRegistration: + def test_lookup_by_language_enum(self) -> None: + support = get_language_support(Language.GO) + assert support.language == Language.GO + + def test_lookup_by_extension(self) -> None: + support = get_language_support(Path("main.go")) + assert support.language == Language.GO + + def test_lookup_by_string(self) -> None: + support = get_language_support("go") + assert support.language == Language.GO + + def test_lookup_by_dot_extension(self) -> None: + support = get_language_support(".go") + assert support.language == Language.GO + + +class TestGoSupportDiscoverFunctions: + def test_discovers_functions(self) -> None: + support = GoSupport() + source = """\ +package calc + +func Add(a, b int) int { + return a + b +} + +func subtract(a, b int) int { + return a - b +} +""" + results = support.discover_functions(source, Path("/project/calc.go")) + names = [f.function_name for f in results] + assert "Add" in names + assert "subtract" in names + + def test_validate_syntax_valid(self) -> None: + support = GoSupport() + assert support.validate_syntax("package main\n\nfunc main() {}") is True + + def test_validate_syntax_invalid(self) -> None: + support = GoSupport() + assert support.validate_syntax("func {{{ invalid") is False + + +class TestGoSupportHelpers: + def test_find_test_root(self) -> None: + support = GoSupport() + root = Path("/project") + assert support.find_test_root(root) == root + + def test_get_runtime_files(self) -> None: + support = GoSupport() + assert support.get_runtime_files() == [] + + def test_instrument_for_behavior_passthrough(self) -> None: + support = GoSupport() + source = "package main\n\nfunc main() {}\n" + assert support.instrument_for_behavior(source, []) == source + + def test_instrument_for_benchmarking_passthrough(self) -> None: + support = GoSupport() + source = "package main\n\nfunc Test() {}\n" + result = support.instrument_for_benchmarking(source, None) # type: ignore[arg-type] + assert result == source + + def test_get_test_dir_for_source(self) -> None: + support = GoSupport() + source_file = Path("/project/pkg/calc.go") + result = support.get_test_dir_for_source(Path("/project"), source_file) + assert result == Path("/project/pkg") + + def test_get_module_path(self) -> None: + support = GoSupport() + source_file = Path("/project/pkg/calc.go") + result = support.get_module_path(source_file, Path("/project")) + assert result == str(source_file) + + def test_setup_test_config_returns_true(self) -> None: + support = GoSupport() + + class FakeTestCfg: + project_root_path = Path("/nonexistent") + + assert support.setup_test_config(FakeTestCfg(), Path("/file.go")) is True + + def test_prepare_module_valid(self, tmp_path: Path) -> None: + from codeflash.models.models import ValidCode + + support = GoSupport() + code = "package main\n\nfunc main() {}\n" + module_path = (tmp_path / "main.go").resolve() + result = support.prepare_module(code, module_path, tmp_path) + assert result is not None + validated, ast_node = result + assert ast_node is None + assert module_path in validated + assert isinstance(validated[module_path], ValidCode) + assert validated[module_path].source_code == code + + def test_prepare_module_invalid(self, tmp_path: Path) -> None: + support = GoSupport() + result = support.prepare_module("func {{{ invalid", (tmp_path / "bad.go").resolve(), tmp_path) + assert result is None + + def test_instrument_existing_test_reads_file(self, tmp_path: Path) -> None: + support = GoSupport() + test_file = (tmp_path / "calc_test.go").resolve() + test_file.write_text("package calc\n\nfunc TestAdd(t *testing.T) {}\n", encoding="utf-8") + success, content = support.instrument_existing_test( + test_path=test_file, call_positions=[], function_to_optimize=None, tests_project_root=tmp_path, mode="behavior" + ) + assert success is True + assert content is not None + assert "TestAdd" in content + + def test_instrument_existing_test_missing_file(self, tmp_path: Path) -> None: + support = GoSupport() + success, content = support.instrument_existing_test( + test_path=(tmp_path / "missing.go").resolve(), + call_positions=[], + function_to_optimize=None, + tests_project_root=tmp_path, + mode="behavior", + ) + assert success is False + assert content is None + + def test_postprocess_generated_tests_passthrough(self) -> None: + support = GoSupport() + sentinel = object() + result = support.postprocess_generated_tests(sentinel, "go-test", Path("/project"), Path("/project/calc.go")) # type: ignore[arg-type] + assert result is sentinel + + def test_process_generated_test_strings_passthrough(self) -> None: + support = GoSupport() + gen, beh, perf = support.process_generated_test_strings( + "gen_code", "beh_code", "perf_code", None, Path("/test.go"), None, None + ) + assert gen == "gen_code" + assert beh == "beh_code" + assert perf == "perf_code" + + def test_add_runtime_comments_to_generated_tests_passthrough(self) -> None: + support = GoSupport() + sentinel = object() + result = support.add_runtime_comments_to_generated_tests(sentinel, {}, {}) # type: ignore[arg-type] + assert result is sentinel + + def test_remove_test_functions_from_generated_tests(self) -> None: + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + support = GoSupport() + source = """\ +package calc + +import "testing" + +func TestAdd(t *testing.T) { +\tif Add(1, 2) != 3 { +\t\tt.Fatal("bad") +\t} +} + +func TestSub(t *testing.T) { +\tif Sub(3, 1) != 2 { +\t\tt.Fatal("bad") +\t} +} +""" + gt = GeneratedTests( + generated_original_test_source=source, + instrumented_behavior_test_source=source, + instrumented_perf_test_source=source, + behavior_file_path=Path("/test_beh.go"), + perf_file_path=Path("/test_perf.go"), + ) + tests_list = GeneratedTestsList(generated_tests=[gt]) + result = support.remove_test_functions_from_generated_tests(tests_list, ["TestSub"]) + assert "TestAdd" in result.generated_tests[0].generated_original_test_source + assert "TestSub" not in result.generated_tests[0].generated_original_test_source diff --git a/tests/test_languages/test_golang/test_test_discovery.py b/tests/test_languages/test_golang/test_test_discovery.py new file mode 100644 index 000000000..c31c8d006 --- /dev/null +++ b/tests/test_languages/test_golang/test_test_discovery.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.test_discovery import ( + _extract_target_name, + _extract_test_body, + _test_calls_function, + discover_tests, +) +from codeflash.models.function_types import FunctionToOptimize + +GO_TEST_SOURCE = """\ +package calc + +import "testing" + +func TestAdd(t *testing.T) { +\tresult := Add(2, 3) +\tif result != 5 { +\t\tt.Fail() +\t} +} + +func TestSubtract(t *testing.T) { +\tresult := Subtract(5, 3) +\tif result != 2 { +\t\tt.Fail() +\t} +} + +func TestHelper(t *testing.T) { +\tx := 1 + 2 +\t_ = x +} +""" + + +class TestExtractTargetName: + def test_simple(self) -> None: + assert _extract_target_name("TestAdd") == "Add" + + def test_with_underscore_suffix(self) -> None: + assert _extract_target_name("TestAdd_negative") == "Add" + + def test_long_name(self) -> None: + assert _extract_target_name("TestFibonacci") == "Fibonacci" + + def test_bare_test(self) -> None: + assert _extract_target_name("Test") is None + + def test_not_a_test(self) -> None: + assert _extract_target_name("NotATest") is None + + +class TestExtractTestBody: + def test_extracts_body(self) -> None: + body = _extract_test_body(GO_TEST_SOURCE, "TestAdd") + assert body == "\n\tresult := Add(2, 3)\n\tif result != 5 {\n\t\tt.Fail()\n\t}\n" + + def test_extracts_second_body(self) -> None: + body = _extract_test_body(GO_TEST_SOURCE, "TestSubtract") + assert body == "\n\tresult := Subtract(5, 3)\n\tif result != 2 {\n\t\tt.Fail()\n\t}\n" + + def test_missing_function(self) -> None: + assert _extract_test_body(GO_TEST_SOURCE, "TestMissing") is None + + +class TestTestCallsFunction: + def test_calls_add(self) -> None: + assert _test_calls_function(GO_TEST_SOURCE, "TestAdd", "Add") is True + + def test_does_not_call_subtract(self) -> None: + assert _test_calls_function(GO_TEST_SOURCE, "TestAdd", "Subtract") is False + + def test_helper_does_not_call_add(self) -> None: + assert _test_calls_function(GO_TEST_SOURCE, "TestHelper", "Add") is False + + +class TestDiscoverTests: + def test_matches_by_name_convention(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text( + "package calc\n\nfunc Add(a, b int) int { return a + b }\n", encoding="utf-8" + ) + (root / "calc_test.go").write_text(GO_TEST_SOURCE, encoding="utf-8") + + funcs = [FunctionToOptimize(function_name="Add", file_path=root / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert "Add" in result + assert len(result["Add"]) == 1 + assert result["Add"][0].test_name == "TestAdd" + + def test_matches_multiple_functions(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text( + "package calc\n\nfunc Add(a, b int) int { return a + b }\n\nfunc Subtract(a, b int) int { return a - b }\n", + encoding="utf-8", + ) + (root / "calc_test.go").write_text(GO_TEST_SOURCE, encoding="utf-8") + + funcs = [ + FunctionToOptimize(function_name="Add", file_path=root / "calc.go", language="go"), + FunctionToOptimize(function_name="Subtract", file_path=root / "calc.go", language="go"), + ] + result = discover_tests(root, funcs) + assert "Add" in result + assert "Subtract" in result + assert result["Add"][0].test_name == "TestAdd" + assert result["Subtract"][0].test_name == "TestSubtract" + + def test_no_match_returns_empty(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text( + "package calc\n\nfunc Multiply(a, b int) int { return a * b }\n", encoding="utf-8" + ) + (root / "calc_test.go").write_text(GO_TEST_SOURCE, encoding="utf-8") + + funcs = [FunctionToOptimize(function_name="Multiply", file_path=root / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert "Multiply" not in result + + def test_no_test_files(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text("package calc\n\nfunc Add(a, b int) int { return a + b }\n", encoding="utf-8") + + funcs = [FunctionToOptimize(function_name="Add", file_path=root / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert result == {} + + def test_subdirectory_test_files(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + pkg = root / "pkg" + pkg.mkdir() + (pkg / "calc.go").write_text( + "package calc\n\nfunc Add(a, b int) int { return a + b }\n", encoding="utf-8" + ) + (pkg / "calc_test.go").write_text(GO_TEST_SOURCE, encoding="utf-8") + + funcs = [FunctionToOptimize(function_name="Add", file_path=pkg / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert "Add" in result + assert result["Add"][0].test_file == pkg / "calc_test.go" + + def test_fallback_content_match(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text( + "package calc\n\nfunc DoMath(a, b int) int { return a + b }\n", encoding="utf-8" + ) + (root / "calc_test.go").write_text( + 'package calc\n\nimport "testing"\n\nfunc TestComputation(t *testing.T) {\n' + "\tresult := DoMath(2, 3)\n\tif result != 5 {\n\t\tt.Fail()\n\t}\n}\n", + encoding="utf-8", + ) + + funcs = [FunctionToOptimize(function_name="DoMath", file_path=root / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert "DoMath" in result + assert result["DoMath"][0].test_name == "TestComputation" diff --git a/tests/test_languages/test_golang/test_test_runner.py b/tests/test_languages/test_golang/test_test_runner.py new file mode 100644 index 000000000..dea932f3c --- /dev/null +++ b/tests/test_languages/test_golang/test_test_runner.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.test_runner import ( + _build_bench_regex, + _build_run_regex, + _collect_other_test_files, + _deduplicate_test_func_names, + _deduplicated_test_files, + _extract_func_names, + _hide_other_test_files, + _test_files_to_packages, + _BENCH_FUNC_RE, + _TEST_FUNC_RE, + parse_go_test_json, + parse_test_results, +) + + +GO_TEST_JSON_ALL_PASS = """\ +{"Time":"2024-01-01T00:00:00Z","Action":"run","Package":"example.com/calc","Test":"TestAdd"} +{"Time":"2024-01-01T00:00:00Z","Action":"output","Package":"example.com/calc","Test":"TestAdd","Output":"=== RUN TestAdd\\n"} +{"Time":"2024-01-01T00:00:00Z","Action":"output","Package":"example.com/calc","Test":"TestAdd","Output":"--- PASS: TestAdd (0.00s)\\n"} +{"Time":"2024-01-01T00:00:00Z","Action":"pass","Package":"example.com/calc","Test":"TestAdd","Elapsed":0.001} +{"Time":"2024-01-01T00:00:00Z","Action":"run","Package":"example.com/calc","Test":"TestSub"} +{"Time":"2024-01-01T00:00:00Z","Action":"output","Package":"example.com/calc","Test":"TestSub","Output":"--- PASS: TestSub (0.00s)\\n"} +{"Time":"2024-01-01T00:00:00Z","Action":"pass","Package":"example.com/calc","Test":"TestSub","Elapsed":0.002} +""" + +GO_TEST_JSON_WITH_FAILURE = """\ +{"Time":"2024-01-01T00:00:00Z","Action":"run","Package":"example.com/calc","Test":"TestAdd"} +{"Time":"2024-01-01T00:00:00Z","Action":"pass","Package":"example.com/calc","Test":"TestAdd","Elapsed":0.001} +{"Time":"2024-01-01T00:00:00Z","Action":"run","Package":"example.com/calc","Test":"TestBroken"} +{"Time":"2024-01-01T00:00:00Z","Action":"output","Package":"example.com/calc","Test":"TestBroken","Output":" calc_test.go:15: expected 5, got 3\\n"} +{"Time":"2024-01-01T00:00:00Z","Action":"fail","Package":"example.com/calc","Test":"TestBroken","Elapsed":0.003} +""" + + +class TestParseGoTestJson: + def test_all_pass(self) -> None: + results = parse_go_test_json(GO_TEST_JSON_ALL_PASS) + assert len(results) == 2 + by_name = {r.test_name: r for r in results} + assert by_name["TestAdd"].passed is True + assert by_name["TestAdd"].runtime_ns == 1_000_000 + assert by_name["TestSub"].passed is True + assert by_name["TestSub"].runtime_ns == 2_000_000 + + def test_with_failure(self) -> None: + results = parse_go_test_json(GO_TEST_JSON_WITH_FAILURE) + assert len(results) == 2 + by_name = {r.test_name: r for r in results} + assert by_name["TestAdd"].passed is True + assert by_name["TestBroken"].passed is False + assert by_name["TestBroken"].error_message == "Test TestBroken failed" + + def test_empty_input(self) -> None: + results = parse_go_test_json("") + assert results == [] + + def test_invalid_json_lines_skipped(self) -> None: + json_output = 'not json\n{"Action":"pass","Package":"calc","Test":"TestOk","Elapsed":0.001}\n' + results = parse_go_test_json(json_output) + assert len(results) == 1 + assert results[0].test_name == "TestOk" + assert results[0].passed is True + + def test_package_level_events_ignored(self) -> None: + json_output = '{"Action":"pass","Package":"example.com/calc","Elapsed":0.5}\n' + results = parse_go_test_json(json_output) + assert results == [] + + def test_runtime_ns_conversion(self) -> None: + json_output = '{"Action":"pass","Package":"calc","Test":"TestFast","Elapsed":0.0005}\n' + results = parse_go_test_json(json_output) + assert len(results) == 1 + assert results[0].runtime_ns == 500_000 + + def test_zero_elapsed(self) -> None: + json_output = '{"Action":"pass","Package":"calc","Test":"TestZero","Elapsed":0}\n' + results = parse_go_test_json(json_output) + assert len(results) == 1 + assert results[0].runtime_ns is None + + +class TestParseTestResults: + def test_reads_from_file(self, tmp_path: Path) -> None: + json_file = (tmp_path / "results.jsonl").resolve() + json_file.write_text('{"Action":"pass","Package":"calc","Test":"TestAdd","Elapsed":0.001}\n', encoding="utf-8") + results = parse_test_results(json_file, "") + assert len(results) == 1 + assert results[0].test_name == "TestAdd" + assert results[0].passed is True + + def test_falls_back_to_stdout(self, tmp_path: Path) -> None: + missing_file = (tmp_path / "missing.jsonl").resolve() + stdout = '{"Action":"fail","Package":"calc","Test":"TestBad","Elapsed":0.002}\n' + results = parse_test_results(missing_file, stdout) + assert len(results) == 1 + assert results[0].test_name == "TestBad" + assert results[0].passed is False + + +class TestCollectOtherTestFiles: + def test_finds_other_test_files_in_same_dir(self, tmp_path: Path) -> None: + keep = (tmp_path / "instrumented_test.go").resolve() + keep.write_text("package x", encoding="utf-8") + other1 = (tmp_path / "sorting_test.go").resolve() + other1.write_text("package x", encoding="utf-8") + other2 = (tmp_path / "perf_test.go").resolve() + other2.write_text("package x", encoding="utf-8") + + result = _collect_other_test_files([keep]) + resolved = {f.resolve() for f in result} + assert other1.resolve() in resolved + assert other2.resolve() in resolved + assert keep.resolve() not in resolved + + def test_keeps_only_specified_files(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text("package x", encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text("package x", encoding="utf-8") + + result = _collect_other_test_files([f1, f2]) + assert result == [] + + def test_ignores_non_test_files(self, tmp_path: Path) -> None: + keep = (tmp_path / "target_test.go").resolve() + keep.write_text("package x", encoding="utf-8") + non_test = (tmp_path / "helper.go").resolve() + non_test.write_text("package x", encoding="utf-8") + + result = _collect_other_test_files([keep]) + assert all(f.name.endswith("_test.go") for f in result) + assert non_test not in result + + def test_empty_list(self) -> None: + assert _collect_other_test_files([]) == [] + + def test_multiple_dirs(self, tmp_path: Path) -> None: + d1 = (tmp_path / "pkg1").resolve() + d1.mkdir() + d2 = (tmp_path / "pkg2").resolve() + d2.mkdir() + keep1 = (d1 / "target_test.go").resolve() + keep1.write_text("package pkg1", encoding="utf-8") + other1 = (d1 / "old_test.go").resolve() + other1.write_text("package pkg1", encoding="utf-8") + keep2 = (d2 / "target_test.go").resolve() + keep2.write_text("package pkg2", encoding="utf-8") + + result = _collect_other_test_files([keep1, keep2]) + resolved = {f.resolve() for f in result} + assert other1.resolve() in resolved + assert keep1.resolve() not in resolved + assert keep2.resolve() not in resolved + + +class TestHideOtherTestFiles: + def test_hides_and_restores(self, tmp_path: Path) -> None: + other = (tmp_path / "sorting_test.go").resolve() + other.write_text("package x\n\nfunc TestSort(t *testing.T) {}", encoding="utf-8") + + with _hide_other_test_files([other]): + assert not other.exists() + assert other.with_suffix(".go.codeflash_hidden").exists() + + assert other.exists() + assert not other.with_suffix(".go.codeflash_hidden").exists() + assert other.read_text(encoding="utf-8") == "package x\n\nfunc TestSort(t *testing.T) {}" + + def test_restores_even_on_exception(self, tmp_path: Path) -> None: + other = (tmp_path / "sorting_test.go").resolve() + other.write_text("content", encoding="utf-8") + + try: + with _hide_other_test_files([other]): + raise RuntimeError("boom") + except RuntimeError: + pass + + assert other.exists() + assert not other.with_suffix(".go.codeflash_hidden").exists() + + def test_empty_list_is_noop(self) -> None: + with _hide_other_test_files([]): + pass + + def test_multiple_files(self, tmp_path: Path) -> None: + files = [] + for name in ("a_test.go", "b_test.go"): + f = (tmp_path / name).resolve() + f.write_text(f"package {name}", encoding="utf-8") + files.append(f) + + with _hide_other_test_files(files): + for f in files: + assert not f.exists() + + for f in files: + assert f.exists() + + +GO_TEST_SOURCE = """\ +package sorting + +import "testing" + +func TestBubbleSort_Basic(t *testing.T) {} +func TestBubbleSort_EdgeCases(t *testing.T) {} +""" + +GO_BENCH_SOURCE = """\ +package sorting + +import "testing" + +func BenchmarkBubbleSort(b *testing.B) {} +func BenchmarkBubbleSort_Large(b *testing.B) {} +""" + +GO_MIXED_SOURCE = """\ +package sorting + +import "testing" + +func TestBubbleSort(t *testing.T) {} +func BenchmarkBubbleSort(b *testing.B) {} +""" + + +class TestExtractFuncNames: + def test_extracts_test_funcs(self, tmp_path: Path) -> None: + f = (tmp_path / "sorting_test.go").resolve() + f.write_text(GO_TEST_SOURCE, encoding="utf-8") + names = _extract_func_names([f], _TEST_FUNC_RE) + assert names == ["TestBubbleSort_Basic", "TestBubbleSort_EdgeCases"] + + def test_extracts_bench_funcs(self, tmp_path: Path) -> None: + f = (tmp_path / "sorting_test.go").resolve() + f.write_text(GO_BENCH_SOURCE, encoding="utf-8") + names = _extract_func_names([f], _BENCH_FUNC_RE) + assert names == ["BenchmarkBubbleSort", "BenchmarkBubbleSort_Large"] + + def test_test_regex_does_not_match_benchmarks(self, tmp_path: Path) -> None: + f = (tmp_path / "sorting_test.go").resolve() + f.write_text(GO_BENCH_SOURCE, encoding="utf-8") + names = _extract_func_names([f], _TEST_FUNC_RE) + assert names == [] + + def test_multiple_files(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text("package x\nfunc TestA(t *testing.T) {}", encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text("package x\nfunc TestB(t *testing.T) {}", encoding="utf-8") + names = _extract_func_names([f1, f2], _TEST_FUNC_RE) + assert names == ["TestA", "TestB"] + + def test_missing_file_skipped(self, tmp_path: Path) -> None: + missing = (tmp_path / "missing_test.go").resolve() + names = _extract_func_names([missing], _TEST_FUNC_RE) + assert names == [] + + def test_empty_list(self) -> None: + assert _extract_func_names([], _TEST_FUNC_RE) == [] + + +class TestBuildRunRegex: + def test_single_test_func(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text("package x\nfunc TestFoo(t *testing.T) {}", encoding="utf-8") + regex = _build_run_regex([f]) + assert regex == "^(TestFoo)$" + + def test_multiple_test_funcs(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text(GO_TEST_SOURCE, encoding="utf-8") + regex = _build_run_regex([f]) + assert regex == "^(TestBubbleSort_Basic|TestBubbleSort_EdgeCases)$" + + def test_no_test_funcs_returns_none(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text("package x\nfunc helper() {}", encoding="utf-8") + assert _build_run_regex([f]) is None + + def test_empty_files_returns_none(self) -> None: + assert _build_run_regex([]) is None + + +class TestBuildBenchRegex: + def test_single_bench_func(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text('package x\nimport "testing"\nfunc BenchmarkFoo(b *testing.B) {}', encoding="utf-8") + regex = _build_bench_regex([f]) + assert regex == "^(BenchmarkFoo)$" + + def test_no_bench_funcs_returns_none(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text(GO_TEST_SOURCE, encoding="utf-8") + assert _build_bench_regex([f]) is None + + +class TestTestFilesToPackages: + def test_subdirectory(self, tmp_path: Path) -> None: + subdir = (tmp_path / "sorting").resolve() + subdir.mkdir() + f = subdir / "sorting_test.go" + f.write_text("package sorting", encoding="utf-8") + packages = _test_files_to_packages([f.resolve()], tmp_path.resolve()) + assert packages == ["./sorting"] + + def test_root_directory(self, tmp_path: Path) -> None: + f = (tmp_path / "main_test.go").resolve() + f.write_text("package main", encoding="utf-8") + packages = _test_files_to_packages([f], tmp_path.resolve()) + assert packages == ["."] + + def test_deduplicates_same_package(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text("package x", encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text("package x", encoding="utf-8") + packages = _test_files_to_packages([f1, f2], tmp_path.resolve()) + assert packages == ["."] + + def test_multiple_packages(self, tmp_path: Path) -> None: + for name in ("pkg1", "pkg2"): + d = (tmp_path / name).resolve() + d.mkdir() + (d / "x_test.go").write_text(f"package {name}", encoding="utf-8") + f1 = (tmp_path / "pkg1" / "x_test.go").resolve() + f2 = (tmp_path / "pkg2" / "x_test.go").resolve() + packages = _test_files_to_packages([f1, f2], tmp_path.resolve()) + assert packages == ["./pkg1", "./pkg2"] + + def test_empty_list(self, tmp_path: Path) -> None: + assert _test_files_to_packages([], tmp_path.resolve()) == [] + + def test_file_outside_cwd_skipped(self, tmp_path: Path) -> None: + other = (tmp_path / "other").resolve() + other.mkdir() + f = (other / "x_test.go").resolve() + f.write_text("package x", encoding="utf-8") + cwd = (tmp_path / "project").resolve() + cwd.mkdir() + assert _test_files_to_packages([f], cwd) == [] + + +GO_FILE_A = """\ +package x + +import "testing" + +func TestFoo(t *testing.T) {} +func TestBar(t *testing.T) {} +""" + +GO_FILE_B_DUPLICATES = """\ +package x + +import "testing" + +func TestFoo(t *testing.T) {} +func TestBaz(t *testing.T) {} +""" + +GO_FILE_C_MORE_DUPLICATES = """\ +package x + +import "testing" + +func TestFoo(t *testing.T) {} +func TestBar(t *testing.T) {} +func TestNew(t *testing.T) {} +""" + + +class TestDeduplicateTestFuncNames: + def test_no_duplicates_no_changes(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + originals = _deduplicate_test_func_names([f1]) + assert originals == {} + assert f1.read_text(encoding="utf-8") == GO_FILE_A + + def test_renames_duplicates_in_second_file(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + + originals = _deduplicate_test_func_names([f1, f2]) + + assert f1.read_text(encoding="utf-8") == GO_FILE_A + assert f2 in originals + assert originals[f2] == GO_FILE_B_DUPLICATES + + rewritten = f2.read_text(encoding="utf-8") + assert "func TestFoo_1(" in rewritten + assert "func TestBaz(" in rewritten + assert "func TestFoo(" not in rewritten + + def test_renames_across_three_files(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + f3 = (tmp_path / "c_test.go").resolve() + f3.write_text(GO_FILE_C_MORE_DUPLICATES, encoding="utf-8") + + _deduplicate_test_func_names([f1, f2, f3]) + + rewritten_b = f2.read_text(encoding="utf-8") + rewritten_c = f3.read_text(encoding="utf-8") + + assert "func TestFoo_1(" in rewritten_b + assert "func TestFoo_2(" in rewritten_c + assert "func TestBar_1(" in rewritten_c + + def test_empty_list(self) -> None: + assert _deduplicate_test_func_names([]) == {} + + def test_single_file_no_changes(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + originals = _deduplicate_test_func_names([f]) + assert originals == {} + + def test_benchmarks_also_deduplicated(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text('package x\n\nimport "testing"\n\nfunc BenchmarkFoo(b *testing.B) {}\n', encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text('package x\n\nimport "testing"\n\nfunc BenchmarkFoo(b *testing.B) {}\n', encoding="utf-8") + + _deduplicate_test_func_names([f1, f2]) + + assert "func BenchmarkFoo(" in f1.read_text(encoding="utf-8") + rewritten = f2.read_text(encoding="utf-8") + assert "func BenchmarkFoo_1(" in rewritten + + +class TestDeduplicatedTestFiles: + def test_restores_after_context(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + + with _deduplicated_test_files([f1, f2]): + assert "func TestFoo_1(" in f2.read_text(encoding="utf-8") + + assert f2.read_text(encoding="utf-8") == GO_FILE_B_DUPLICATES + + def test_restores_on_exception(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + + try: + with _deduplicated_test_files([f1, f2]): + raise RuntimeError("boom") + except RuntimeError: + pass + + assert f2.read_text(encoding="utf-8") == GO_FILE_B_DUPLICATES + + def test_no_duplicates_is_noop(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text(GO_FILE_A, encoding="utf-8") + + with _deduplicated_test_files([f]): + assert f.read_text(encoding="utf-8") == GO_FILE_A diff --git a/uv.lock b/uv.lock index 6a2b2a0f0..7f858adaf 100644 --- a/uv.lock +++ b/uv.lock @@ -500,6 +500,8 @@ dependencies = [ { name = "tomlkit" }, { name = "tree-sitter", version = "0.23.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "tree-sitter", version = "0.25.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "tree-sitter-go", version = "0.23.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "tree-sitter-go", version = "0.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "tree-sitter-groovy" }, { name = "tree-sitter-java" }, { name = "tree-sitter-javascript", version = "0.23.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, @@ -616,6 +618,7 @@ requires-dist = [ { name = "sentry-sdk", specifier = ">=2.58.0,<3.0.0" }, { name = "tomlkit", specifier = ">=0.14.0" }, { name = "tree-sitter", specifier = ">=0.23.2" }, + { name = "tree-sitter-go", specifier = ">=0.23.0" }, { name = "tree-sitter-groovy", specifier = ">=0.1.2" }, { name = "tree-sitter-java", specifier = ">=0.23.5" }, { name = "tree-sitter-javascript", specifier = ">=0.23.1" }, @@ -6112,6 +6115,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/6e/e64621037357acb83d912276ffd30a859ef117f9c680f2e3cb955f47c680/tree_sitter-0.25.2-cp314-cp314-win_arm64.whl", hash = "sha256:b8d4429954a3beb3e844e2872610d2a4800ba4eb42bb1990c6a4b1949b18459f", size = 117470, upload-time = "2025-09-25T17:37:58.431Z" }, ] +[[package]] +name = "tree-sitter-go" +version = "0.23.4" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.9.2' and python_full_version < '3.10'", + "python_full_version < '3.9.2'", +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/7f/13b83b877043faadecb5cb70982589ed79e7ebd78f8d239128dc6b23f595/tree_sitter_go-0.23.4.tar.gz", hash = "sha256:0ebff99820657066bec21690623a14c74d9e57a903f95f0837be112ddadf1a52", size = 85686, upload-time = "2024-11-24T19:37:18.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/2d/070137fa47215265459bef90b27902471ddcd61530c3331437bcd9ba93cd/tree_sitter_go-0.23.4-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c9320f87a05cd47fa0f627b9329bbc09b7ed90de8fe4f5882aed318d6e19962d", size = 45689, upload-time = "2024-11-24T19:37:07.228Z" }, + { url = "https://files.pythonhosted.org/packages/37/8a/9e1dc1c1cefcf060b0105fb294c399ec4808fa1f9e2cbf0463f991b28aed/tree_sitter_go-0.23.4-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:914e63d16b36ab0e4f52b031e574b82d17d0bbfecca138ae83e887a1cf5b71ac", size = 47364, upload-time = "2024-11-24T19:37:08.835Z" }, + { url = "https://files.pythonhosted.org/packages/d6/8a/6c1f26d25cfcedd22d452a299bf9a753d97d5ebd8db4d2047f2002b5b301/tree_sitter_go-0.23.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:330ecbb38d6ea4ef41eba2d473056889705e64f6a51c2fb613de05b1bcb5ba22", size = 66543, upload-time = "2024-11-24T19:37:10.738Z" }, + { url = "https://files.pythonhosted.org/packages/f2/03/d82c4b61db9e29b272aed6742cde37244312e63860048fd66d927bfc4f50/tree_sitter_go-0.23.4-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd14d23056ae980debfccc0db67d0a168da03792ca2968b1b5dd58ce288084e7", size = 65498, upload-time = "2024-11-24T19:37:12.375Z" }, + { url = "https://files.pythonhosted.org/packages/03/15/c37db75ff873042f74b1eec214fda84dfff985406ccdc94e4d2be9a6888b/tree_sitter_go-0.23.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:c3b40912487fdb78c4028860dd79493a521ffca0104f209849823358db3618a0", size = 64391, upload-time = "2024-11-24T19:37:13.944Z" }, + { url = "https://files.pythonhosted.org/packages/e3/cc/a32de9c9391a859dd5fc938922bb6cd5b7d6114c88998411433e06fe4572/tree_sitter_go-0.23.4-cp39-abi3-win_amd64.whl", hash = "sha256:ae4b231cad2ef76401d33617879cda6321c4d0853f7fd98cb5654c50a218effb", size = 46954, upload-time = "2024-11-24T19:37:14.953Z" }, + { url = "https://files.pythonhosted.org/packages/ec/35/a533173cd846385796eed56dde62eb908b3500e6308fddb4ddc30dc227b8/tree_sitter_go-0.23.4-cp39-abi3-win_arm64.whl", hash = "sha256:2ac907362a3c347145dc1da0858248546500a323de90d2cb76d2a3fdbfc8da25", size = 45276, upload-time = "2024-11-24T19:37:16.623Z" }, +] + +[[package]] +name = "tree-sitter-go" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.10.*'", +] +sdist = { url = "https://files.pythonhosted.org/packages/01/05/727308adbbc79bcb1c92fc0ea10556a735f9d0f0a5435a18f59d40f7fd77/tree_sitter_go-0.25.0.tar.gz", hash = "sha256:a7466e9b8d94dda94cae8d91629f26edb2d26166fd454d4831c3bf6dfa2e8d68", size = 93890, upload-time = "2025-08-29T06:20:25.044Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/aa/0984707acc2b9bb461fe4a41e7e0fc5b2b1e245c32820f0c83b3c602957c/tree_sitter_go-0.25.0-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b852993063a3429a443e7bd0aa376dd7dd329d595819fabf56ac4cf9d7257b54", size = 47117, upload-time = "2025-08-29T06:20:14.286Z" }, + { url = "https://files.pythonhosted.org/packages/32/16/dd4cb124b35e99239ab3624225da07d4cb8da4d8564ed81d03fcb3a6ba9f/tree_sitter_go-0.25.0-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:503b81a2b4c31e302869a1de3a352ad0912ccab3df9ac9950197b0a9ceeabd8f", size = 48674, upload-time = "2025-08-29T06:20:17.557Z" }, + { url = "https://files.pythonhosted.org/packages/86/fb/b30d63a08044115d8b8bd196c6c2ab4325fb8db5757249a4ef0563966e2e/tree_sitter_go-0.25.0-cp310-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:04b3b3cb4aff18e74e28d49b716c6f24cb71ddfdd66768987e26e4d0fa812f74", size = 66418, upload-time = "2025-08-29T06:20:18.345Z" }, + { url = "https://files.pythonhosted.org/packages/26/21/d3d88a30ad007419b2c97b3baeeef7431407faf9f686195b6f1cad0aedf9/tree_sitter_go-0.25.0-cp310-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:148255aca2f54b90d48c48a9dbb4c7faad6cad310a980b2c5a5a9822057ed145", size = 72006, upload-time = "2025-08-29T06:20:19.14Z" }, + { url = "https://files.pythonhosted.org/packages/cd/d0/0dd6442353ced8a88bbda9e546f4ea29e381b59b5a40b122e5abb586bb6c/tree_sitter_go-0.25.0-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4d338116cdf8a6c6ff990d2441929b41323ef17c710407abe0993c13417d6aad", size = 70603, upload-time = "2025-08-29T06:20:21.544Z" }, + { url = "https://files.pythonhosted.org/packages/01/e2/ee5e09f63504fc286539535d374d2eaa0e7d489b80f8f744bb3962aff22a/tree_sitter_go-0.25.0-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5608e089d2a29fa8d2b327abeb2ad1cdb8e223c440a6b0ceab0d3fa80bdeebae", size = 66088, upload-time = "2025-08-29T06:20:22.336Z" }, + { url = "https://files.pythonhosted.org/packages/6e/b6/d9142583374720e79aca9ccb394b3795149a54c012e1dfd80738df2d984e/tree_sitter_go-0.25.0-cp310-abi3-win_amd64.whl", hash = "sha256:30d4ada57a223dfc2c32d942f44d284d40f3d1215ddcf108f96807fd36d53022", size = 48152, upload-time = "2025-08-29T06:20:23.089Z" }, + { url = "https://files.pythonhosted.org/packages/9e/00/9a2638e7339236f5b01622952a4d71c1474dd3783d1982a89555fc1f03b1/tree_sitter_go-0.25.0-cp310-abi3-win_arm64.whl", hash = "sha256:d5d62362059bf79997340773d47cc7e7e002883b527a05cca829c46e40b70ded", size = 46752, upload-time = "2025-08-29T06:20:24.235Z" }, +] + [[package]] name = "tree-sitter-groovy" version = "0.1.2"