From 955b7088a568ff2bcfffd36868262699f22e6ec3 Mon Sep 17 00:00:00 2001 From: gmegidish Date: Sun, 25 Jan 2026 11:25:47 +0100 Subject: [PATCH 1/3] feat: websocket support through /ws --- go.mod | 1 + go.sum | 2 + server/server.go | 3 + server/websocket.go | 154 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 160 insertions(+) create mode 100644 server/websocket.go diff --git a/go.mod b/go.mod index 7e53445..5883ea2 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25.0 require ( github.com/danielpaulus/go-ios v1.0.182 + github.com/gorilla/websocket v1.5.3 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index 1a48f5a..d921599 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE= github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/server/server.go b/server/server.go index cd2c229..7d7499d 100644 --- a/server/server.go +++ b/server/server.go @@ -99,6 +99,9 @@ func StartServer(addr string, enableCORS bool) error { mux.HandleFunc("/", sendBanner) mux.HandleFunc("/rpc", handleJSONRPC) + mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + handleWebSocket(w, r, enableCORS) + }) // if host is missing, default to localhost if !strings.Contains(addr, ":") { diff --git a/server/websocket.go b/server/websocket.go new file mode 100644 index 0000000..323ee4f --- /dev/null +++ b/server/websocket.go @@ -0,0 +1,154 @@ +package server + +import ( + "encoding/json" + "log" + "net/http" + "net/url" + "sync" + + "github.com/gorilla/websocket" + "github.com/mobile-next/mobilecli/utils" +) + +type wsConnection struct { + conn *websocket.Conn + writeMu sync.Mutex +} + +func newUpgrader(enableCORS bool) *websocket.Upgrader { + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + if enableCORS { + upgrader.CheckOrigin = func(r *http.Request) bool { + return true + } + } else { + upgrader.CheckOrigin = isSameOrigin + } + + return &upgrader +} + +func handleWebSocket(w http.ResponseWriter, r *http.Request, enableCORS bool) { + conn, err := newUpgrader(enableCORS).Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade failed: %v", err) + return + } + defer conn.Close() + + wsConn := &wsConnection{conn: conn} + + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + // connection closed or error + utils.Verbose("WebSocket connection closed: %v", err) + break + } + + if messageType != websocket.TextMessage { + wsConn.sendError(nil, ErrCodeInvalidRequest, "Invalid Request", "only text messages accepted for requests") + continue + } + + handleWSMessage(wsConn, message) + } +} + +func isSameOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + + originURL, err := url.Parse(origin) + if err != nil { + return false + } + + return originURL.Host == r.Host +} + +func handleWSMessage(wsConn *wsConnection, message []byte) { + var req JSONRPCRequest + if err := json.Unmarshal(message, &req); err != nil { + wsConn.sendError(nil, ErrCodeParseError, "Parse error", "expecting jsonrpc payload") + return + } + + if req.JSONRPC != "2.0" { + wsConn.sendError(req.ID, ErrCodeInvalidRequest, "Invalid Request", "'jsonrpc' must be '2.0'") + return + } + + if req.ID == nil { + wsConn.sendError(nil, ErrCodeInvalidRequest, "Invalid Request", "'id' field is required") + return + } + + if req.Method == "" { + wsConn.sendError(req.ID, ErrCodeInvalidRequest, "Invalid Request", "'method' is required") + return + } + + // screencapture is not supported over WebSocket + if req.Method == "screencapture" { + wsConn.sendError(req.ID, ErrCodeMethodNotFound, "Method not supported", "screencapture not supported over WebSocket, use HTTP /rpc endpoint") + return + } + + utils.Info("WebSocket Request ID: %v, Method: %s, Params: %s", req.ID, req.Method, string(req.Params)) + + handleWSMethodCall(wsConn, req) +} + +func handleWSMethodCall(wsConn *wsConnection, req JSONRPCRequest) { + registry := GetMethodRegistry() + handler, exists := registry[req.Method] + if !exists { + wsConn.sendError(req.ID, ErrCodeMethodNotFound, "Method not found", req.Method+" not found") + return + } + + result, err := handler(req.Params) + if err != nil { + log.Printf("Error executing method %s: %v", req.Method, err) + wsConn.sendError(req.ID, ErrCodeServerError, "Server error", err.Error()) + return + } + + wsConn.sendResponse(req.ID, result) +} + +func (wsc *wsConnection) sendResponse(id interface{}, result interface{}) error { + response := JSONRPCResponse{ + JSONRPC: "2.0", + Result: result, + ID: id, + } + return wsc.sendJSON(response) +} + +func (wsc *wsConnection) sendError(id interface{}, code int, message string, data interface{}) error { + response := JSONRPCResponse{ + JSONRPC: "2.0", + Error: map[string]interface{}{ + "code": code, + "message": message, + "data": data, + }, + ID: id, + } + return wsc.sendJSON(response) +} + +func (wsc *wsConnection) sendJSON(v interface{}) error { + wsc.writeMu.Lock() + defer wsc.writeMu.Unlock() + return wsc.conn.WriteJSON(v) +} From d36d7240ef38af91f812c0cedfb47574ed1eab30 Mon Sep 17 00:00:00 2001 From: gmegidish Date: Sun, 25 Jan 2026 14:39:24 +0100 Subject: [PATCH 2/3] chore: ran clean code skill --- server/server.go | 4 +- server/websocket.go | 155 ++++++++++++++++++++++++++++++++++++-------- 2 files changed, 129 insertions(+), 30 deletions(-) diff --git a/server/server.go b/server/server.go index 7d7499d..0f6fcf9 100644 --- a/server/server.go +++ b/server/server.go @@ -99,9 +99,7 @@ func StartServer(addr string, enableCORS bool) error { mux.HandleFunc("/", sendBanner) mux.HandleFunc("/rpc", handleJSONRPC) - mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { - handleWebSocket(w, r, enableCORS) - }) + mux.HandleFunc("/ws", NewWebSocketHandler(enableCORS)) // if host is missing, default to localhost if !strings.Contains(addr, ":") { diff --git a/server/websocket.go b/server/websocket.go index 323ee4f..b180f2b 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "sync" + "time" "github.com/gorilla/websocket" "github.com/mobile-next/mobilecli/utils" @@ -16,6 +17,30 @@ type wsConnection struct { writeMu sync.Mutex } +type validationError struct { + code int + message string + data interface{} +} + +const ( + wsMaxMessageSize = 64 * 1024 + wsWriteWait = 10 * time.Second + wsPongWait = 60 * time.Second + wsPingPeriod = (wsPongWait * 9) / 10 + + jsonRPCVersion = "2.0" + errMsgParseError = "expecting jsonrpc payload" + errMsgInvalidJSONRPC = "'jsonrpc' must be '2.0'" + errMsgIDRequired = "'id' field is required" + errMsgMethodRequired = "'method' is required" + errMsgTextOnly = "only text messages accepted for requests" + errMsgScreencapture = "screencapture not supported over WebSocket, use HTTP /rpc endpoint" + errTitleParseError = "Parse error" + errTitleInvalidReq = "Invalid Request" + errTitleMethodNotSupp = "Method not supported" +) + func newUpgrader(enableCORS bool) *websocket.Upgrader { upgrader := websocket.Upgrader{ ReadBufferSize: 1024, @@ -33,26 +58,63 @@ func newUpgrader(enableCORS bool) *websocket.Upgrader { return &upgrader } -func handleWebSocket(w http.ResponseWriter, r *http.Request, enableCORS bool) { - conn, err := newUpgrader(enableCORS).Upgrade(w, r, nil) +func upgradeConnection(w http.ResponseWriter, r *http.Request, upgrader *websocket.Upgrader) (*websocket.Conn, error) { + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - log.Printf("WebSocket upgrade failed: %v", err) - return + return nil, err + } + return conn, nil +} + +func configureConnection(conn *websocket.Conn) { + conn.SetReadLimit(wsMaxMessageSize) + if err := conn.SetReadDeadline(time.Now().Add(wsPongWait)); err != nil { + utils.Verbose("failed to set read deadline: %v", err) } - defer conn.Close() + conn.SetPongHandler(func(string) error { + return conn.SetReadDeadline(time.Now().Add(wsPongWait)) + }) +} - wsConn := &wsConnection{conn: conn} +func startPingRoutine(wsConn *wsConnection) func() { + pingDone := make(chan struct{}) + go pingLoop(wsConn, pingDone) + return func() { close(pingDone) } +} +func pingLoop(wsConn *wsConnection, done <-chan struct{}) { + ticker := time.NewTicker(wsPingPeriod) + defer ticker.Stop() for { - messageType, message, err := conn.ReadMessage() + select { + case <-ticker.C: + wsConn.writeMu.Lock() + if err := wsConn.conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil { + utils.Verbose("failed to set write deadline: %v", err) + wsConn.writeMu.Unlock() + return + } + if err := wsConn.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + wsConn.writeMu.Unlock() + return + } + wsConn.writeMu.Unlock() + case <-done: + return + } + } +} + +func readMessages(wsConn *wsConnection) { + for { + messageType, message, err := wsConn.conn.ReadMessage() if err != nil { - // connection closed or error utils.Verbose("WebSocket connection closed: %v", err) break } if messageType != websocket.TextMessage { - wsConn.sendError(nil, ErrCodeInvalidRequest, "Invalid Request", "only text messages accepted for requests") + wsConn.sendError(nil, ErrCodeInvalidRequest, errTitleInvalidReq, errMsgTextOnly) continue } @@ -60,6 +122,25 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request, enableCORS bool) { } } +func NewWebSocketHandler(enableCORS bool) http.HandlerFunc { + upgrader := newUpgrader(enableCORS) + return func(w http.ResponseWriter, r *http.Request) { + conn, err := upgradeConnection(w, r, upgrader) + if err != nil { + log.Printf("WebSocket upgrade failed: %v", err) + return + } + defer conn.Close() + + wsConn := &wsConnection{conn: conn} + configureConnection(conn) + stopPing := startPingRoutine(wsConn) + defer stopPing() + + readMessages(wsConn) + } +} + func isSameOrigin(r *http.Request) bool { origin := r.Header.Get("Origin") if origin == "" { @@ -74,31 +155,51 @@ func isSameOrigin(r *http.Request) bool { return originURL.Host == r.Host } -func handleWSMessage(wsConn *wsConnection, message []byte) { - var req JSONRPCRequest - if err := json.Unmarshal(message, &req); err != nil { - wsConn.sendError(nil, ErrCodeParseError, "Parse error", "expecting jsonrpc payload") - return - } - - if req.JSONRPC != "2.0" { - wsConn.sendError(req.ID, ErrCodeInvalidRequest, "Invalid Request", "'jsonrpc' must be '2.0'") - return +func validateJSONRPCRequest(req JSONRPCRequest) *validationError { + if req.JSONRPC != jsonRPCVersion { + return &validationError{ + code: ErrCodeInvalidRequest, + message: errTitleInvalidReq, + data: errMsgInvalidJSONRPC, + } } if req.ID == nil { - wsConn.sendError(nil, ErrCodeInvalidRequest, "Invalid Request", "'id' field is required") - return + return &validationError{ + code: ErrCodeInvalidRequest, + message: errTitleInvalidReq, + data: errMsgIDRequired, + } } if req.Method == "" { - wsConn.sendError(req.ID, ErrCodeInvalidRequest, "Invalid Request", "'method' is required") - return + return &validationError{ + code: ErrCodeInvalidRequest, + message: errTitleInvalidReq, + data: errMsgMethodRequired, + } } - // screencapture is not supported over WebSocket if req.Method == "screencapture" { - wsConn.sendError(req.ID, ErrCodeMethodNotFound, "Method not supported", "screencapture not supported over WebSocket, use HTTP /rpc endpoint") + return &validationError{ + code: ErrCodeMethodNotFound, + message: errTitleMethodNotSupp, + data: errMsgScreencapture, + } + } + + return nil +} + +func handleWSMessage(wsConn *wsConnection, message []byte) { + var req JSONRPCRequest + if err := json.Unmarshal(message, &req); err != nil { + wsConn.sendError(nil, ErrCodeParseError, errTitleParseError, errMsgParseError) + return + } + + if validationErr := validateJSONRPCRequest(req); validationErr != nil { + wsConn.sendError(req.ID, validationErr.code, validationErr.message, validationErr.data) return } @@ -127,7 +228,7 @@ func handleWSMethodCall(wsConn *wsConnection, req JSONRPCRequest) { func (wsc *wsConnection) sendResponse(id interface{}, result interface{}) error { response := JSONRPCResponse{ - JSONRPC: "2.0", + JSONRPC: jsonRPCVersion, Result: result, ID: id, } @@ -136,7 +237,7 @@ func (wsc *wsConnection) sendResponse(id interface{}, result interface{}) error func (wsc *wsConnection) sendError(id interface{}, code int, message string, data interface{}) error { response := JSONRPCResponse{ - JSONRPC: "2.0", + JSONRPC: jsonRPCVersion, Error: map[string]interface{}{ "code": code, "message": message, From a3fdb20382231b2bfd86fde1839e146959f09cda Mon Sep 17 00:00:00 2001 From: gmegidish Date: Sun, 25 Jan 2026 15:55:38 +0100 Subject: [PATCH 3/3] feat: added support for audiocapture on ios --- cli/audiocapture.go | 89 +++++++++++++++++++ cli/flags.go | 2 + commands/audiocapture.go | 6 ++ devices/android.go | 6 +- devices/common.go | 8 ++ devices/ios.go | 95 ++++++++++++++++++++ devices/simulator.go | 4 + server/server.go | 84 ++++++++++++++++++ server/websocket.go | 9 ++ utils/ogg_opus.go | 186 +++++++++++++++++++++++++++++++++++++++ 10 files changed, 488 insertions(+), 1 deletion(-) create mode 100644 cli/audiocapture.go create mode 100644 commands/audiocapture.go create mode 100644 utils/ogg_opus.go diff --git a/cli/audiocapture.go b/cli/audiocapture.go new file mode 100644 index 0000000..3a2592b --- /dev/null +++ b/cli/audiocapture.go @@ -0,0 +1,89 @@ +package cli + +import ( + "fmt" + "os" + + "github.com/mobile-next/mobilecli/commands" + "github.com/mobile-next/mobilecli/devices" + "github.com/mobile-next/mobilecli/utils" + "github.com/spf13/cobra" +) + +var audiocaptureCmd = &cobra.Command{ + Use: "audiocapture", + Short: "Stream audio capture from a connected device", + Long: "Streams audio capture from a specified device to stdout. Supports Opus (real iOS devices only).", + RunE: func(cmd *cobra.Command, args []string) error { + if audiocaptureFormat != "opus+rtp" && audiocaptureFormat != "opus+ogg" { + response := commands.NewErrorResponse(fmt.Errorf("format must be 'opus+rtp' or 'opus+ogg' for audio capture")) + printJson(response) + return fmt.Errorf("%s", response.Error) + } + + targetDevice, err := commands.FindDeviceOrAutoSelect(deviceId) + if err != nil { + response := commands.NewErrorResponse(fmt.Errorf("error finding device: %v", err)) + printJson(response) + return fmt.Errorf("%s", response.Error) + } + + if targetDevice.Platform() != "ios" || targetDevice.DeviceType() != "real" { + response := commands.NewErrorResponse(fmt.Errorf("audio capture is only supported on real iOS devices")) + printJson(response) + return fmt.Errorf("%s", response.Error) + } + + var parser *utils.OpusFrameParser + var oggWriter *utils.OggOpusWriter + if audiocaptureFormat == "opus+ogg" { + var err error + oggWriter, err = utils.NewOggOpusWriter(os.Stdout) + if err != nil { + response := commands.NewErrorResponse(fmt.Errorf("failed to initialize ogg writer: %v", err)) + printJson(response) + return fmt.Errorf("%s", response.Error) + } + parser = utils.NewOpusFrameParser(func(packet []byte) error { + return oggWriter.WritePacket(packet) + }) + } + + err = targetDevice.StartAudioCapture(devices.AudioCaptureConfig{ + Format: audiocaptureFormat, + OnProgress: func(message string) { + utils.Verbose(message) + }, + OnData: func(data []byte) bool { + if parser != nil { + if err := parser.Write(data); err != nil { + fmt.Fprintf(os.Stderr, "Error writing Ogg Opus data: %v\n", err) + return false + } + } else { + _, writeErr := os.Stdout.Write(data) + if writeErr != nil { + fmt.Fprintf(os.Stderr, "Error writing data: %v\n", writeErr) + return false + } + } + return true + }, + }) + + if err != nil { + response := commands.NewErrorResponse(fmt.Errorf("error starting audio capture: %v", err)) + printJson(response) + return fmt.Errorf("%s", response.Error) + } + + return nil + }, +} + +func init() { + rootCmd.AddCommand(audiocaptureCmd) + + audiocaptureCmd.Flags().StringVar(&deviceId, "device", "", "ID of the device to capture from") + audiocaptureCmd.Flags().StringVarP(&audiocaptureFormat, "format", "f", "opus+rtp", "Output format for audio capture (opus+rtp, opus+ogg)") +} diff --git a/cli/flags.go b/cli/flags.go index 3b5dfd9..e50a395 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -13,6 +13,8 @@ var ( // for screencapture command screencaptureFormat string + // for audiocapture command + audiocaptureFormat string // for devices command platform string diff --git a/commands/audiocapture.go b/commands/audiocapture.go new file mode 100644 index 0000000..a9314a3 --- /dev/null +++ b/commands/audiocapture.go @@ -0,0 +1,6 @@ +package commands + +type AudioCaptureRequest struct { + DeviceID string `json:"deviceId"` + Format string `json:"format"` +} diff --git a/devices/android.go b/devices/android.go index 4b432be..d01caec 100644 --- a/devices/android.go +++ b/devices/android.go @@ -628,7 +628,7 @@ func isAscii(text string) bool { // escapeShellText escapes shell special characters func escapeShellText(text string) string { // escape all shell special characters that could be used for injection - specialChars := `\'"`+ "`" + ` + specialChars := `\'"` + "`" + ` |&;()<>{}[]$*?` result := "" for _, char := range text { @@ -849,6 +849,10 @@ func (d *AndroidDevice) StartScreenCapture(config ScreenCaptureConfig) error { return nil } +func (d *AndroidDevice) StartAudioCapture(config AudioCaptureConfig) error { + return fmt.Errorf("audio capture is only supported on real iOS devices") +} + func (d *AndroidDevice) installPackage(apkPath string) error { output, err := d.runAdbCommand("install", apkPath) if err != nil { diff --git a/devices/common.go b/devices/common.go index 20d413e..e71efc7 100644 --- a/devices/common.go +++ b/devices/common.go @@ -28,6 +28,13 @@ type ScreenCaptureConfig struct { OnData func([]byte) bool // data callback - return false to stop } +// AudioCaptureConfig contains configuration for audio capture operations +type AudioCaptureConfig struct { + Format string + OnProgress func(message string) // optional progress callback + OnData func([]byte) bool // data callback - return false to stop +} + // StartAgentConfig contains configuration for agent startup operations type StartAgentConfig struct { OnProgress func(message string) // optional progress callback @@ -65,6 +72,7 @@ type ControllableDevice interface { UninstallApp(packageName string) (*InstalledAppInfo, error) Info() (*FullDeviceInfo, error) StartScreenCapture(config ScreenCaptureConfig) error + StartAudioCapture(config AudioCaptureConfig) error DumpSource() ([]ScreenElement, error) DumpSourceRaw() (interface{}, error) GetOrientation() (string, error) diff --git a/devices/ios.go b/devices/ios.go index e1751f3..76d3978 100644 --- a/devices/ios.go +++ b/devices/ios.go @@ -32,6 +32,7 @@ const ( portRangeEnd = 8299 deviceKitHTTPPort = 12004 // device-side HTTP server port deviceKitStreamPort = 12005 // device-side H.264 TCP stream port + deviceKitAudioPort = 12006 // device-side Opus audio TCP stream port deviceKitAppLaunchTimeout = 5 * time.Second deviceKitBroadcastTimeout = 5 * time.Second ) @@ -775,6 +776,78 @@ func (d IOSDevice) StartScreenCapture(config ScreenCaptureConfig) error { return d.mjpegClient.StartScreenCapture(config.Format, config.OnData) } +func (d IOSDevice) StartAudioCapture(config AudioCaptureConfig) error { + if config.Format != "opus+rtp" && config.Format != "opus+ogg" { + return fmt.Errorf("format must be 'opus+rtp' or 'opus+ogg' for audio capture") + } + + if d.Platform() != "ios" || d.DeviceType() != "real" { + return fmt.Errorf("audio capture is only supported on real iOS devices") + } + + if config.OnProgress != nil { + config.OnProgress("Starting port forwarding for Opus audio stream") + } + + localAudioPort, err := findAvailablePortInRange(portRangeStart, portRangeEnd) + if err != nil { + return fmt.Errorf("failed to find available port for audio: %w", err) + } + + audioForwarder := ios.NewPortForwarder(d.ID()) + err = audioForwarder.Forward(localAudioPort, deviceKitAudioPort) + if err != nil { + return fmt.Errorf("failed to forward audio port: %w", err) + } + defer func() { _ = audioForwarder.Stop() }() + + if config.OnProgress != nil { + config.OnProgress(fmt.Sprintf("Connecting to Opus stream on localhost:%d", localAudioPort)) + } + + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", localAudioPort)) + if err != nil { + return fmt.Errorf("failed to connect to audio stream port: %w", err) + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + done := make(chan error, 1) + go func() { + defer conn.Close() + buffer := make([]byte, 65536) + for { + n, err := conn.Read(buffer) + if err != nil { + if err != io.EOF { + done <- fmt.Errorf("error reading from audio stream: %w", err) + } else { + done <- nil + } + return + } + + if n > 0 { + if !config.OnData(buffer[:n]) { + done <- nil + return + } + } + } + }() + + select { + case <-sigChan: + conn.Close() + utils.Verbose("audio stream closed by user") + return nil + case err := <-done: + utils.Verbose("audio stream ended") + return err + } +} + func (d IOSDevice) DumpSource() ([]ScreenElement, error) { return d.wdaClient.GetSourceElements() } @@ -857,6 +930,7 @@ func (d IOSDevice) SetOrientation(orientation string) error { type DeviceKitInfo struct { HTTPPort int `json:"httpPort"` StreamPort int `json:"streamPort"` + AudioPort int `json:"audioPort"` } // clickStartBroadcastButton polls for the "BroadcastUploadExtension" button, taps it, @@ -1052,6 +1126,23 @@ func (d *IOSDevice) StartDeviceKit() (*DeviceKitInfo, error) { } utils.Verbose("Port forwarding started: localhost:%d -> device:%d (H.264 stream)", localStreamPort, deviceKitStreamPort) + // Find available local port for audio forwarding. + localAudioPort, err := findAvailablePortInRange(portRangeStart, portRangeEnd) + if err != nil { + _ = httpForwarder.Stop() + _ = streamForwarder.Stop() + return nil, fmt.Errorf("failed to find available port for audio: %w", err) + } + + audioForwarder := ios.NewPortForwarder(d.ID()) + err = audioForwarder.Forward(localAudioPort, deviceKitAudioPort) + if err != nil { + _ = httpForwarder.Stop() + _ = streamForwarder.Stop() + return nil, fmt.Errorf("failed to forward audio port: %w", err) + } + utils.Verbose("Port forwarding started: localhost:%d -> device:%d (Opus stream)", localAudioPort, deviceKitAudioPort) + // Launch the main DeviceKit app utils.Verbose("Launching DeviceKit app: %s", devicekitMainAppBundleId) err = d.LaunchApp(devicekitMainAppBundleId) @@ -1059,6 +1150,7 @@ func (d *IOSDevice) StartDeviceKit() (*DeviceKitInfo, error) { // clean up port forwarders on failure _ = httpForwarder.Stop() _ = streamForwarder.Stop() + _ = audioForwarder.Stop() return nil, fmt.Errorf("failed to launch DeviceKit app: %w", err) } @@ -1076,6 +1168,7 @@ func (d *IOSDevice) StartDeviceKit() (*DeviceKitInfo, error) { // clean up port forwarders on failure _ = httpForwarder.Stop() _ = streamForwarder.Stop() + _ = audioForwarder.Stop() return nil, fmt.Errorf("failed to start agent: %w", err) } @@ -1085,6 +1178,7 @@ func (d *IOSDevice) StartDeviceKit() (*DeviceKitInfo, error) { // clean up port forwarders on failure _ = httpForwarder.Stop() _ = streamForwarder.Stop() + _ = audioForwarder.Stop() return nil, fmt.Errorf("failed to click Start Broadcast button: %w", err) } @@ -1097,6 +1191,7 @@ func (d *IOSDevice) StartDeviceKit() (*DeviceKitInfo, error) { return &DeviceKitInfo{ HTTPPort: localHTTPPort, StreamPort: localStreamPort, + AudioPort: localAudioPort, }, nil } diff --git a/devices/simulator.go b/devices/simulator.go index 591edcc..15c29dd 100644 --- a/devices/simulator.go +++ b/devices/simulator.go @@ -723,6 +723,10 @@ func (s *SimulatorDevice) StartScreenCapture(config ScreenCaptureConfig) error { return mjpegClient.StartScreenCapture(config.Format, config.OnData) } +func (s *SimulatorDevice) StartAudioCapture(config AudioCaptureConfig) error { + return fmt.Errorf("audio capture is only supported on real iOS devices") +} + type ProcessInfo struct { PID int Command string diff --git a/server/server.go b/server/server.go index 0f6fcf9..d75fce4 100644 --- a/server/server.go +++ b/server/server.go @@ -198,6 +198,16 @@ func handleJSONRPC(w http.ResponseWriter, r *http.Request) { return } + // Special case: audiocapture is streaming and has different signature + if req.Method == "audiocapture" { + err = handleAudioCapture(r, w, req.Params) + if err != nil { + log.Printf("Error in audio capture: %v", err) + sendJSONRPCError(w, req.ID, ErrCodeServerError, "Server error", err.Error()) + } + return + } + // HTTP-specific: device_boot needs extended timeout (can take up to 2 minutes) if req.Method == "device_boot" { _ = http.NewResponseController(w).SetWriteDeadline(time.Now().Add(3 * time.Minute)) @@ -934,3 +944,77 @@ func handleScreenCapture(r *http.Request, w http.ResponseWriter, params json.Raw return nil } + +func handleAudioCapture(r *http.Request, w http.ResponseWriter, params json.RawMessage) error { + _ = http.NewResponseController(w).SetWriteDeadline(time.Now().Add(10 * time.Minute)) + + var audioCaptureParams commands.AudioCaptureRequest + if err := json.Unmarshal(params, &audioCaptureParams); err != nil { + return fmt.Errorf("invalid parameters: %v", err) + } + + targetDevice, err := commands.FindDeviceOrAutoSelect(audioCaptureParams.DeviceID) + if err != nil { + return fmt.Errorf("error finding device: %w", err) + } + + if audioCaptureParams.Format == "" { + audioCaptureParams.Format = "opus+rtp" + } + + if audioCaptureParams.Format != "opus+rtp" && audioCaptureParams.Format != "opus+ogg" { + return fmt.Errorf("format must be 'opus+rtp' or 'opus+ogg' for audio capture") + } + + if targetDevice.Platform() != "ios" || targetDevice.DeviceType() != "real" { + return fmt.Errorf("audio capture is only supported on real iOS devices") + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Transfer-Encoding", "chunked") + + var parser *utils.OpusFrameParser + var oggWriter *utils.OggOpusWriter + if audioCaptureParams.Format == "opus+ogg" { + var err error + oggWriter, err = utils.NewOggOpusWriter(w) + if err != nil { + return fmt.Errorf("failed to initialize ogg writer: %v", err) + } + parser = utils.NewOpusFrameParser(func(packet []byte) error { + return oggWriter.WritePacket(packet) + }) + } + + err = targetDevice.StartAudioCapture(devices.AudioCaptureConfig{ + Format: audioCaptureParams.Format, + OnData: func(data []byte) bool { + if parser != nil { + if err := parser.Write(data); err != nil { + fmt.Println("Error writing Ogg Opus data:", err) + return false + } + } else { + _, writeErr := w.Write(data) + if writeErr != nil { + fmt.Println("Error writing data:", writeErr) + return false + } + } + + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + + return true + }, + }) + + if err != nil { + return fmt.Errorf("error starting audio capture: %v", err) + } + + return nil +} diff --git a/server/websocket.go b/server/websocket.go index b180f2b..4098130 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -36,6 +36,7 @@ const ( errMsgMethodRequired = "'method' is required" errMsgTextOnly = "only text messages accepted for requests" errMsgScreencapture = "screencapture not supported over WebSocket, use HTTP /rpc endpoint" + errMsgAudiocapture = "audiocapture not supported over WebSocket, use HTTP /rpc endpoint" errTitleParseError = "Parse error" errTitleInvalidReq = "Invalid Request" errTitleMethodNotSupp = "Method not supported" @@ -188,6 +189,14 @@ func validateJSONRPCRequest(req JSONRPCRequest) *validationError { } } + if req.Method == "audiocapture" { + return &validationError{ + code: ErrCodeMethodNotFound, + message: errTitleMethodNotSupp, + data: errMsgAudiocapture, + } + } + return nil } diff --git a/utils/ogg_opus.go b/utils/ogg_opus.go new file mode 100644 index 0000000..910333f --- /dev/null +++ b/utils/ogg_opus.go @@ -0,0 +1,186 @@ +package utils + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "fmt" + "io" +) + +const ( + opusSampleRate = 48000 + opusFrameSize = 960 + opusPreskip = 312 +) + +var oggCRCTable = buildOggCRCTable() + +type OggOpusWriter struct { + w io.Writer + serial uint32 + seq uint32 + granule uint64 +} + +func NewOggOpusWriter(w io.Writer) (*OggOpusWriter, error) { + serial := randomSerial() + writer := &OggOpusWriter{ + w: w, + serial: serial, + } + + if err := writer.writeOpusHead(); err != nil { + return nil, err + } + if err := writer.writeOpusTags(); err != nil { + return nil, err + } + return writer, nil +} + +func (o *OggOpusWriter) WritePacket(packet []byte) error { + o.granule += opusFrameSize + return o.writePage(0x00, o.granule, packet) +} + +func (o *OggOpusWriter) writeOpusHead() error { + head := make([]byte, 19) + copy(head, []byte("OpusHead")) + head[8] = 1 + head[9] = 1 + binary.LittleEndian.PutUint16(head[10:], opusPreskip) + binary.LittleEndian.PutUint32(head[12:], opusSampleRate) + binary.LittleEndian.PutUint16(head[16:], 0) + head[18] = 0 + return o.writePage(0x02, 0, head) +} + +func (o *OggOpusWriter) writeOpusTags() error { + vendor := []byte("mobilecli") + packetLen := 8 + 4 + len(vendor) + 4 + packet := make([]byte, packetLen) + copy(packet, []byte("OpusTags")) + binary.LittleEndian.PutUint32(packet[8:], uint32(len(vendor))) + copy(packet[12:], vendor) + binary.LittleEndian.PutUint32(packet[12+len(vendor):], 0) + return o.writePage(0x00, 0, packet) +} + +func (o *OggOpusWriter) writePage(headerType uint8, granule uint64, packet []byte) error { + segments := make([]byte, 0, 255) + remaining := len(packet) + for remaining > 0 { + seg := remaining + if seg > 255 { + seg = 255 + } + segments = append(segments, byte(seg)) + remaining -= seg + } + if len(segments) > 255 { + return fmt.Errorf("opus packet too large for single ogg page") + } + + headerLen := 27 + len(segments) + page := make([]byte, headerLen+len(packet)) + copy(page[0:4], []byte("OggS")) + page[4] = 0 + page[5] = headerType + binary.LittleEndian.PutUint64(page[6:], granule) + binary.LittleEndian.PutUint32(page[14:], o.serial) + binary.LittleEndian.PutUint32(page[18:], o.seq) + binary.LittleEndian.PutUint32(page[22:], 0) + page[26] = byte(len(segments)) + copy(page[27:], segments) + copy(page[headerLen:], packet) + + crc := oggCRC(page) + binary.LittleEndian.PutUint32(page[22:], crc) + + o.seq++ + _, err := o.w.Write(page) + return err +} + +type OpusFrameParser struct { + buf []byte + onPacket func([]byte) error +} + +func NewOpusFrameParser(onPacket func([]byte) error) *OpusFrameParser { + return &OpusFrameParser{ + onPacket: onPacket, + } +} + +func (p *OpusFrameParser) Write(data []byte) error { + p.buf = append(p.buf, data...) + for { + if len(p.buf) < 4 { + return nil + } + packetLen := int(binary.BigEndian.Uint32(p.buf[:4])) + if packetLen <= 0 { + p.buf = p.buf[4:] + continue + } + if len(p.buf) < 4+packetLen { + return nil + } + packet := make([]byte, packetLen) + copy(packet, p.buf[4:4+packetLen]) + p.buf = p.buf[4+packetLen:] + if err := p.onPacket(packet); err != nil { + return err + } + } +} + +func randomSerial() uint32 { + var b [4]byte + if _, err := rand.Read(b[:]); err == nil { + return binary.LittleEndian.Uint32(b[:]) + } + return 0x4f505553 +} + +func buildOggCRCTable() [256]uint32 { + var table [256]uint32 + for i := 0; i < 256; i++ { + r := uint32(i) << 24 + for j := 0; j < 8; j++ { + if r&0x80000000 != 0 { + r = (r << 1) ^ 0x04C11DB7 + } else { + r <<= 1 + } + } + table[i] = r + } + return table +} + +func oggCRC(data []byte) uint32 { + var crc uint32 + for _, b := range data { + crc = (crc << 8) ^ oggCRCTable[((crc>>24)&0xFF)^uint32(b)] + } + return crc +} + +func (o *OggOpusWriter) WriteToBuffer(packet []byte) ([]byte, error) { + var buf bytes.Buffer + writer := &OggOpusWriter{ + w: &buf, + serial: o.serial, + seq: o.seq, + granule: o.granule, + } + if err := writer.WritePacket(packet); err != nil { + return nil, err + } + o.seq = writer.seq + o.granule = writer.granule + return buf.Bytes(), nil +}