diff --git a/codexctl/__init__.py b/codexctl/__init__.py index 9d445c8..d9ecb1a 100644 --- a/codexctl/__init__.py +++ b/codexctl/__init__.py @@ -7,11 +7,12 @@ import importlib.util import tempfile import shutil -import json import re -from typing import cast -from os import listdir +from typing import Any, Callable, cast +from .device import HardwareType + +from .updates import UpdateManager try: from loguru import logger @@ -23,10 +24,6 @@ "Requests is required for accessing remote files. Please install it." ) -from .device import HardwareType -from .updates import UpdateManager - - class Manager: """ Main class for codexctl @@ -39,11 +36,11 @@ def __init__(self, device: str, logger: logging.Logger) -> None: device (str): Type of device that is running the script logger (logger): Logger object """ - self.device = device - self.logger = logger - self.updater = UpdateManager(logger) + self.device: str = device + self.logger: logging.Logger = logger + self.updater: UpdateManager = UpdateManager(logger) - def call_func(self, function: str, args: dict) -> None: + def call_func(self, function: str, args: dict[str, Any]) -> None: """Runs a command based on the function name and arguments provided Args: @@ -55,9 +52,9 @@ def call_func(self, function: str, args: dict) -> None: remarkable_version = HardwareType.parse(self.device) except ValueError: hw = args.get("hardware") - remarkable_version = HardwareType.parse(hw) if hw else None + remarkable_version = cast(str, HardwareType.parse(hw)) if hw else None - version = cast(str | None, args.get("version", None)) + version = cast(Callable[[str, None], str | None], args.get)("version", None) if remarkable_version: if version == "latest": @@ -177,7 +174,8 @@ def call_func(self, function: str, args: dict) -> None: ) else: rmWeb.upload(input_paths=args["paths"], remoteFolder=args["remote"]) - + + ### Update & Version functionalities elif function in ("install", "status", "restore"): remote = False @@ -198,9 +196,9 @@ def call_func(self, function: str, args: dict) -> None: remarkable = DeviceManager( remote=remote, - address=args["address"], + address=cast(str, args["address"]), logger=self.logger, - authentication=args["password"], + authentication=cast(str, args["password"]), ) if version == "latest": @@ -285,7 +283,7 @@ def version_lookup(version: str | None) -> re.Match[str] | None: ) != os.path.abspath("updates"): if not os.path.exists("updates"): os.mkdir("updates") - shutil.move(update_file, "updates") + _ = shutil.move(update_file, "updates") update_file = get_available_version(version) made_update_folder = True # Delete at end @@ -323,7 +321,7 @@ def version_lookup(version: str | None) -> re.Match[str] | None: remarkable.install_ohma_update(update_file) if made_update_folder: # Move update file back out - shutil.move(os.listdir("updates")[0], "../") + _ = shutil.move(os.listdir("updates")[0], "../") shutil.rmtree("updates") os.chdir(orig_cwd) @@ -337,7 +335,7 @@ def main() -> None: ### Setting up the argument parser parser = argparse.ArgumentParser("Codexctl") - parser.add_argument( + _ = parser.add_argument( "--verbose", "-v", required=False, @@ -345,7 +343,7 @@ def main() -> None: action="store_true", dest="verbose", ) - parser.add_argument( + _ = parser.add_argument( "--address", "-a", required=False, @@ -353,7 +351,7 @@ def main() -> None: default=None, dest="address", ) - parser.add_argument( + _ = parser.add_argument( "--password", "-p", required=False, @@ -374,9 +372,9 @@ def main() -> None: download = subparsers.add_parser( "download", help="Download the specified version firmware file" ) - download.add_argument("version", help="Version to download") - download.add_argument("--out", "-o", help="Folder to download to", default=None) - download.add_argument( + _ = download.add_argument("version", help="Version to download") + _ = download.add_argument("--out", "-o", help="Folder to download to", default=None) + _ = download.add_argument( "--hardware", "--device", "-d", @@ -389,35 +387,35 @@ def main() -> None: backup = subparsers.add_parser( "backup", help="Download remote files to local directory" ) - backup.add_argument( + _ = backup.add_argument( "-r", "--remote", help="Remote directory to backup. Defaults to download folder", default="", dest="remote", ) - backup.add_argument( + _ = backup.add_argument( "-l", "--local", help="Local directory to backup to. Defaults to download folder", default="./", dest="local", ) - backup.add_argument( + _ = backup.add_argument( "-R", "--no-recursion", help="Disables recursively backup remote directory", action="store_true", dest="no_recursion", ) - backup.add_argument( + _ = backup.add_argument( "-O", "--no-overwrite", help="Disables overwrite", action="store_true", dest="no_overwrite", ) - backup.add_argument( + _ = backup.add_argument( "-i", "--incremental", help="Overwrite out-of-date files only", @@ -428,40 +426,48 @@ def main() -> None: cat = subparsers.add_parser( "cat", help="Cat the contents of a file inside a firmwareimage" ) - cat.add_argument("file", help="Path to update file to cat", default=None) - cat.add_argument("target_path", help="Path inside the image to list", default=None) + _ = cat.add_argument("file", help="Path to update file to cat", default=None) + _ = cat.add_argument( + "target_path", help="Path inside the image to list", default=None + ) ### Ls subcommand ls = subparsers.add_parser("ls", help="List files inside a firmware image") - ls.add_argument("file", help="Path to update file to extract", default=None) - ls.add_argument("target_path", help="Path inside the image to list", default=None) + _ = ls.add_argument("file", help="Path to update file to extract", default=None) + _ = ls.add_argument( + "target_path", help="Path inside the image to list", default=None + ) ### Extract subcommand extract = subparsers.add_parser( "extract", help="Extract the specified version update file" ) - extract.add_argument("file", help="Path to update file to extract", default=None) - extract.add_argument("--out", help="Folder to extract to", default=None, dest="out") + _ = extract.add_argument( + "file", help="Path to update file to extract", default=None + ) + _ = extract.add_argument( + "--out", help="Folder to extract to", default=None, dest="out" + ) ### Mount subcommand mount = subparsers.add_parser( "mount", help="Mount the specified version firmware filesystem" ) - mount.add_argument( + _ = mount.add_argument( "filesystem", help="Path to version firmware filesystem to extract", default=None, ) - mount.add_argument("--out", help="Folder to mount to", default=None) + _ = mount.add_argument("--out", help="Folder to mount to", default=None) ### Upload subcommand upload = subparsers.add_parser( "upload", help="Upload folder/files to device (pdf only)" ) - upload.add_argument( + _ = upload.add_argument( "paths", help="Path to file(s)/folder to upload", default=None, nargs="+" ) - upload.add_argument( + _ = upload.add_argument( "-r", "--remote", help="Remote directory to upload to. Defaults to root folder", @@ -470,12 +476,12 @@ def main() -> None: ) ### Status subcommand - subparsers.add_parser( + _ = subparsers.add_parser( "status", help="Get the current version of the device and other information" ) ### Restore subcommand - subparsers.add_parser( + _ = subparsers.add_parser( "restore", help="Restores to previous version installed on device" ) diff --git a/codexctl/analysis.py b/codexctl/analysis.py index 8e8fe85..9e3ada7 100644 --- a/codexctl/analysis.py +++ b/codexctl/analysis.py @@ -1,12 +1,12 @@ import ext4 -import warnings +import warnings import errno from remarkable_update_image import UpdateImage from remarkable_update_image import UpdateImageSignatureException -def get_update_image(file: str): +def get_update_image(file: str) -> tuple[UpdateImage, ext4.Volume]: """Extracts files from an update image (<3.11 currently)""" image = UpdateImage(file) diff --git a/codexctl/device.py b/codexctl/device.py index 71e7e45..49e5beb 100644 --- a/codexctl/device.py +++ b/codexctl/device.py @@ -7,6 +7,8 @@ import os import time +from typing import cast + from .server import startUpdate try: @@ -74,7 +76,11 @@ def toltec_type(self): class DeviceManager: def __init__( - self, logger=None, remote=False, address=None, authentication=None + self, + logger: logging.Logger | None = None, + remote: bool = False, + address: str | None = None, + authentication: str | None = None, ) -> None: """Initializes the DeviceManager for codexctl @@ -84,13 +90,10 @@ def __init__( logger (logger, optional): Logger object for logging. Defaults to None. Authentication (str, optional): Authentication method. Defaults to None. """ - self.logger = logger - self.address = address - self.authentication = authentication - self.client = None - - if self.logger is None: - self.logger = logging + self.logger: logging.Logger = logger or cast(logging.Logger, logging) # pyright:ignore [reportInvalidCast] + self.address: str | None = address + self.authentication: str | None = authentication + self.client: paramiko.client.SSHClient | None = None if remote: self.client = self.connect_to_device( @@ -107,16 +110,15 @@ def __init__( with open("/sys/devices/soc0/machine") as file: machine_contents = file.read().strip("\n") - self.hardware = HardwareType.parse(machine_contents) + self.hardware = HardwareType.parse(machine_contents) - def get_host_address(self) -> list[str] | list | None: # Interaction required + def get_host_address(self) -> str: # Interaction required """Gets the IP address of the host machine Returns: str | None: IP address of the host machine, or None if not found """ - - possible_ips = [] + possible_ips: list[str] = [] try: for interface, snics in psutil.net_if_addrs().items(): self.logger.debug(f"New interface found: {interface}") @@ -170,7 +172,7 @@ def get_remarkable_address(self) -> str: print(f"Error: Device {remote_ip} is not reachable. Please try again.") - def check_is_address_reachable(self, remote_ip="10.11.99.1") -> bool: + def check_is_address_reachable(self, remote_ip: str | None = "10.11.99.1") -> bool: """Checks if the given IP address is reachable over SSH Args: @@ -194,7 +196,7 @@ def check_is_address_reachable(self, remote_ip="10.11.99.1") -> bool: return False def connect_to_device( - self, remote_address=None, authentication=None + self, remote_address: str | None = None, authentication: str | None = None ) -> paramiko.client.SSHClient: """Connects to the device using the given IP address @@ -205,7 +207,6 @@ def connect_to_device( Returns: paramiko.client.SSHClient: SSH client object for the device. """ - if remote_address is None: remote_address = self.get_remarkable_address() self.address = remote_address # For future reference @@ -274,7 +275,7 @@ def connect_to_device( return client - def get_device_status(self) -> tuple[str | None, str, str]: + def get_device_status(self) -> tuple[str, bool, str, str | None]: """Gets the status of the device Returns: @@ -282,6 +283,7 @@ def get_device_status(self) -> tuple[str | None, str, str]: """ old_update_engine = True + version_id: str | None = None if self.client: self.logger.debug("Connecting to FTP") ftp = self.client.open_sftp() @@ -303,10 +305,10 @@ def get_device_status(self) -> tuple[str | None, str, str]: old_update_engine = False with ftp.file("/etc/version") as file: - version_id = file.read().decode("utf-8").strip("\n") + version_id = cast(str, file.read().decode("utf-8").strip("\n")) with ftp.file("/home/root/.config/remarkable/xochitl.conf") as file: - beta_contents = file.read().decode("utf-8") + beta_contents = cast(str, file.read().decode("utf-8")) else: if os.path.exists("/usr/share/remarkable/update.conf"): @@ -376,12 +378,12 @@ def set_server_config(self, contents: str, server_host_name: str) -> str: return converted - def edit_update_conf(self, server_ip: str, server_port: str) -> bool: + def edit_update_conf(self, server_ip: str, server_port: int) -> bool: """Edits the update.conf file to point to the given server IP and port Args: server_ip (str): IP of update server - server_port (str): Port of update service + server_port (int): Port of update service Returns: bool: True if successful, False otherwise @@ -399,7 +401,7 @@ def edit_update_conf(self, server_ip: str, server_port: str) -> bool: ) with open("/usr/share/remarkable/update.conf", "w") as file: - file.write(modified_conf_version) + _ = file.write(modified_conf_version) return True @@ -422,8 +424,21 @@ def edit_update_conf(self, server_ip: str, server_port: str) -> bool: def restore_previous_version(self) -> None: """Restores the previous version of the device""" - - RESTORE_CODE = """/sbin/fw_setenv "upgrade_available" "1" + + if self.hardware == HardwareType.RMPP: + RESTORE_CODE = """#!/bin/bash +OLDPART=$(< /sys/devices/platform/lpgpr/root_part) +if [[ $OLDPART == "a" ]]; then + NEWPART="b" +else + NEWPART="a" +fi +echo "new: ${NEWPART}" +echo "fallback: ${OLDPART}" +echo $NEWPART > /sys/devices/platform/lpgpr/root_part +""" + else: + RESTORE_CODE = """/sbin/fw_setenv "upgrade_available" "1" /sbin/fw_setenv "bootcount" "0" OLDPART=$(/sbin/fw_printenv -n active_partition) @@ -436,21 +451,8 @@ def restore_previous_version(self) -> None: echo "fallback: ${OLDPART}" /sbin/fw_setenv "fallback_partition" "${OLDPART}" -/sbin/fw_setenv "active_partition" "${NEWPART}\"""" - - if self.hardware == HardwareType.RMPP: - RESTORE_CODE = """#!/bin/bash -OLDPART=$(< /sys/devices/platform/lpgpr/root_part) -if [[ $OLDPART == "a" ]]; then - NEWPART="b" -else - NEWPART="a" -fi -echo "new: ${NEWPART}" -echo "fallback: ${OLDPART}" -echo $NEWPART > /sys/devices/platform/lpgpr/root_part +/sbin/fw_setenv "active_partition" "${NEWPART}\" """ - if self.client: self.logger.debug("Connecting to FTP") ftp = self.client.open_sftp() @@ -465,12 +467,12 @@ def restore_previous_version(self) -> None: self.client.exec_command("bash /tmp/restore.sh") else: with open("/tmp/restore.sh", "w") as file: - file.write(RESTORE_CODE) + _ = file.write(RESTORE_CODE) self.logger.debug("Setting permissions and running restore.sh") - os.system("chmod +x /tmp/restore.sh") - os.system("/tmp/restore.sh") + _ = os.system("chmod +x /tmp/restore.sh") + _ = os.system("/tmp/restore.sh") self.logger.debug("Restore script ran") @@ -494,10 +496,10 @@ def reboot_device(self) -> None: else: with open("/tmp/reboot.sh", "w") as file: - file.write(REBOOT_CODE) + _ = file.write(REBOOT_CODE) self.logger.debug("Running reboot.sh") - os.system("sh /tmp/reboot.sh") + _ = os.system("sh /tmp/reboot.sh") self.logger.debug("Device rebooted") @@ -603,9 +605,11 @@ def install_sw_update(self, version_file: str) -> None: ) print("Update complete and device rebooting") - os.system("reboot") + _ = os.system("reboot") - def install_ohma_update(self, version_available: dict) -> None: + def install_ohma_update( + self, version_available: dict[str, tuple[str, str]] + ) -> None: """Installs version from update folder on the device Args: @@ -692,7 +696,7 @@ def install_ohma_update(self, version_available: dict) -> None: else: print("Enabling update service") - subprocess.run( + _ = subprocess.run( ["/bin/systemctl", "start", "update-engine"], text=True, check=True, @@ -716,7 +720,7 @@ def install_ohma_update(self, version_available: dict) -> None: ) print("Update complete and device rebooting") - os.system("reboot") + _ = os.system("reboot") @staticmethod def output_put_progress(transferred: int, toBeTransferred: int) -> None: diff --git a/codexctl/server.py b/codexctl/server.py index 766d8f7..ca46361 100644 --- a/codexctl/server.py +++ b/codexctl/server.py @@ -38,10 +38,10 @@ """ -def getupdateinfo(platform, version, update_name): +def getupdateinfo(update_name: str) -> tuple[str, str, int]: full_path = os.path.join("updates", update_name) - update_size = str(os.path.getsize(full_path)) + update_size = os.path.getsize(full_path) BUF_SIZE = 8192 @@ -59,7 +59,7 @@ def getupdateinfo(platform, version, update_name): return (update_sha1, update_sha256, update_size) -def get_available_version(version): +def get_available_version(version: str): available_versions = scanUpdates() for device, ids in available_versions.items(): @@ -69,9 +69,9 @@ def get_available_version(version): return available_version -def scanUpdates(): +def scanUpdates() -> dict[str, tuple[str, str]]: files = os.listdir("updates") - versions = {} + versions: dict[str, tuple[str, str]] = {} for f in files: p = f.split("_") @@ -95,7 +95,7 @@ def scanUpdates(): class MySimpleHTTPRequestHandler(SimpleHTTPRequestHandler): def do_POST(self): - length = int(self.headers.get("Content-Length")) + length = int(self.headers.get("Content-Length") or 0) body = self.rfile.read(length).decode("utf-8") # print(body) print("Updating...") @@ -105,21 +105,23 @@ def do_POST(self): # check for update if updatecheck_node is not None: version = xml.attrib["version"] - platform = xml.find("os").attrib["platform"] + os = xml.find("os") + if os is None: + raise Exception("os tag missing from results") + + platform = os.attrib["platform"] print("requested: ", version) print("platform: ", platform) version, update_name = available_versions[platform] - update_sha1, update_sha256, update_size = getupdateinfo( - platform, version, update_name - ) + update_sha1, update_sha256, update_size = getupdateinfo(update_name) params = { "version": version, "update_name": f"updates/{update_name}", "update_sha1": update_sha1, "update_sha256": update_sha256, - "update_size": update_size, + "update_size": str(update_size), "codebase_url": host_url, } @@ -128,10 +130,13 @@ def do_POST(self): # print(response) self.send_response(200) self.end_headers() - self.wfile.write(response.encode()) + _ = self.wfile.write(response.encode()) return event_node = xml.find("app/event") + if event_node is None: + raise Exception("app/event tag missing from results") + event_type = int(event_node.attrib["eventtype"]) event_result = int(event_node.attrib["eventresult"]) @@ -148,11 +153,11 @@ def do_POST(self): print(response_ok) self.send_response(200) self.end_headers() - self.wfile.write(response_ok.encode()) + _ = self.wfile.write(response_ok.encode()) return -def startUpdate(versionsGiven, host, port=8080): +def startUpdate(versionsGiven: dict[str, tuple[str, str]], host: str, port: int = 8080): global available_versions global host_url # I am aware globals are generally bad practice, but this is a quick and dirty solution diff --git a/codexctl/sync.py b/codexctl/sync.py index 5d80d37..72cd67d 100644 --- a/codexctl/sync.py +++ b/codexctl/sync.py @@ -5,22 +5,31 @@ import requests +from typing import IO, Any, cast -class RmWebInterfaceAPI: # TODO: Add docstrings - def __init__(self, BASE="http://10.11.99.1/", logger=None): - self.logger = logger - if self.logger is None: - self.logger = logging +class RmWebInterfaceAPI: # TODO: Add docstrings + def __init__( + self, BASE: str = "http://10.11.99.1/", logger: logging.Logger | None = None + ): + self.logger: logging.Logger = logger or cast(logging.Logger, logging) # pyright:ignore [reportInvalidCast] - self.BASE = BASE - self.ID_ATTRIBUTE = "ID" - self.NAME_ATTRIBUTE = "VissibleName" - self.MTIME_ATTRIBUTE = "ModifiedClient" + self.BASE: str = BASE + self.ID_ATTRIBUTE: str = "ID" + self.NAME_ATTRIBUTE: str = "VissibleName" + self.MTIME_ATTRIBUTE: str = "ModifiedClient" self.logger.debug(f"Base is: {BASE}") - def __POST(self, endpoint, data={}, fileUpload=False): + def __POST( + self, + endpoint: str, + data: dict[str, str | IO[bytes]] | None = None, + fileUpload: bool = False, + ) -> bytes | Any: + if data is None: + data = {} + try: logging.debug( f"Sending POST request to {self.BASE + endpoint} with data {data}" @@ -43,11 +52,19 @@ def __POST(self, endpoint, data={}, fileUpload=False): return None def __get_documents_recursive( - self, folderId="", currentLocation="", currentDocuments=[] + self, + folderId: str = "", + currentLocation: str = "", + currentDocuments: list[dict[str, Any]] | None = None, ): data = self.__POST(f"documents/{folderId}") + if not isinstance(data, list): + raise Exception("Unexpected result from server") - for item in data: + if currentDocuments is None: + currentDocuments = [] + + for item in cast(list[dict[str, Any]], data): self.logger.debug(f"Checking item: {item}") if "fileType" in item: @@ -57,49 +74,51 @@ def __get_documents_recursive( self.logger.debug( f"Getting documents over {item[self.ID_ATTRIBUTE]}, current location is {currentLocation}/{item[self.NAME_ATTRIBUTE]}" ) - self.__get_documents_recursive( - item[self.ID_ATTRIBUTE], + _ = self.__get_documents_recursive( + cast(str, item[self.ID_ATTRIBUTE]), f"{currentLocation}/{item[self.NAME_ATTRIBUTE]}", currentDocuments, ) return currentDocuments - def __get_folder_id(self, folderName, _from=""): + def __get_folder_id(self, folderName: str, _from: str = "") -> str | None: results = self.__POST(f"documents/{_from}") if results is None: return None + if not isinstance(results, list): + raise Exception("Unexpected result from server") + results.reverse() # We only want folders - for data in results: + for data in cast(list[dict[str, Any]], results): self.logger.debug(f"Folder: {data}") if "fileType" in data: return None - if data[self.NAME_ATTRIBUTE].strip() == folderName.strip(): - return data[self.ID_ATTRIBUTE] + identifier = cast(str, data[self.ID_ATTRIBUTE]) + if cast(str, data[self.NAME_ATTRIBUTE]).strip() == folderName.strip(): + return identifier - self.logger.debug( - f"Getting folders over {folderName}, {data[self.ID_ATTRIBUTE]}" - ) + self.logger.debug(f"Getting folders over {folderName}, {identifier}") - recursiveResults = self.__get_folder_id(folderName, data[self.ID_ATTRIBUTE]) - if recursiveResults is None: - continue - else: + recursiveResults = self.__get_folder_id(folderName, identifier) + if recursiveResults is not None: return recursiveResults - def __get_docs(self, folderName="", recursive=True): + def __get_docs( + self, folderName: str = "", recursive: bool = True + ) -> list[dict[str, Any]]: folderId = "" if folderName: folderId = self.__get_folder_id(folderName) if folderId is None: - return {} + return [] if recursive: self.logger.debug(f"Calling recursive function on {folderName}") @@ -109,13 +128,23 @@ def __get_docs(self, folderName="", recursive=True): data = self.__POST(f"documents/{folderId}") + if not isinstance(data, list): + raise Exception("Unexpected result from server") + + data = cast(list[dict[str, Any]], data) for item in data: item["location"] = "" return [item for item in data if "fileType" in item] - def download(self, document, location="", overwrite=False, incremental=False): - filename = document[self.NAME_ATTRIBUTE] + def download( + self, + document: dict[str, Any], + location: str = "", + overwrite: bool = False, + incremental: bool = False, + ): + filename = cast(str, document[self.NAME_ATTRIBUTE]) if "/" in filename: filename = filename.replace("/", "_") @@ -146,7 +175,7 @@ def download(self, document, location="", overwrite=False, incremental=False): return False with open(fileLocation, "wb") as outFile: - outFile.write(binaryData) + _ = outFile.write(binaryData) return True @@ -154,15 +183,15 @@ def download(self, document, location="", overwrite=False, incremental=False): print(f"Error trying to download {filename}: {error}") return False - def __is_newer(self, document, fileLocation): - remote_ts = document[self.MTIME_ATTRIBUTE] + def __is_newer(self, document: dict[str, Any], fileLocation: str): + remote_ts = cast(str, document[self.MTIME_ATTRIBUTE]) local_mtime = os.path.getmtime(fileLocation) local_ts = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(local_mtime)) return remote_ts > local_ts - def upload(self, input_paths, remoteFolder): + def upload(self, input_paths: list[str], remoteFolder: str): folderId = "" if remoteFolder: folderId = self.__get_folder_id(remoteFolder) @@ -170,9 +199,10 @@ def upload(self, input_paths, remoteFolder): if folderId is None: raise SystemError(f"Error: Folder {remoteFolder} does not exist!") - self.__POST(f"documents/{folderId}") # Setting up for upload... + _ = self.__POST(f"documents/{folderId}") # Setting up for upload... - errors, documents = [], [] + errors: list[str] = [] + documents: list[str] = [] for document in input_paths: # This needs improvement... if os.path.isdir(document): @@ -209,15 +239,15 @@ def upload(self, input_paths, remoteFolder): if len(errors) > 0: print("The following files failed to upload: " + ",".join(errors)) - print(f"Done! {len(documents)-len(errors)} files were uploaded.") + print(f"Done! {len(documents) - len(errors)} files were uploaded.") def sync( self, - localFolder, - remoteFolder="", - overwrite=False, - incremental=False, - recursive=True, + localFolder: str, + remoteFolder: str = "", + overwrite: bool = False, + incremental: bool = False, + recursive: bool = True, ): count = 0 @@ -227,17 +257,17 @@ def sync( documents = self.__get_docs(remoteFolder, recursive) - if documents == {}: + if not documents: print("No documents were found!") - - else: - for doc in documents: - self.logger.debug(f"Processing {doc}") - count += 1 - self.download( - document=doc, - location=f"{localFolder}/{doc['location']}", - overwrite=overwrite, - incremental=incremental, - ) - print(f"Done! {count} files were exported.") + return + + for doc in documents: + self.logger.debug(f"Processing {doc}") + count += 1 + _ = self.download( + document=doc, + location=f"{localFolder}/{doc['location']}", + overwrite=overwrite, + incremental=incremental, + ) + print(f"Done! {count} files were exported.") diff --git a/codexctl/updates.py b/codexctl/updates.py index 9d20377..7cbcf28 100644 --- a/codexctl/updates.py +++ b/codexctl/updates.py @@ -1,6 +1,5 @@ import os import requests -import uuid import sys import json import hashlib @@ -8,25 +7,25 @@ from pathlib import Path from datetime import datetime - -import xml.etree.ElementTree as ET +from typing import cast from .device import HardwareType class UpdateManager: - def __init__(self, logger=None) -> None: + def __init__(self, logger: logging.Logger | None = None) -> None: """Manager for downloading update versions Args: logger (logger, optional): Logger object for logging. Defaults to None. """ - self.logger = logger - - if self.logger is None: - self.logger = logging + self.logger: logging.Logger = logger or cast(logging.Logger, logging) # pyright:ignore [reportInvalidCast] + self.remarkablepp_versions: dict[str, list[str]] + self.remarkable2_versions: dict[str, list[str]] + self.remarkable1_versions: dict[str, list[str]] + self.external_provider_url: str ( self.remarkablepp_versions, self.remarkable2_versions, @@ -34,7 +33,9 @@ def __init__(self, logger=None) -> None: self.external_provider_url, ) = self.get_remarkable_versions() - def get_remarkable_versions(self) -> tuple[dict, dict, dict, str, str]: + def get_remarkable_versions( + self, + ) -> tuple[dict[str, list[str]], dict[str, list[str]], dict[str, list[str]], str]: """Gets the avaliable versions for the device, by checking the local version-ids.json file and then updating it if necessary Returns: @@ -48,7 +49,7 @@ def get_remarkable_versions(self) -> tuple[dict, dict, dict, str, str]: else: if os.name == "nt": # Windows - folder_location = os.getenv("APPDATA") + "/codexctl" + folder_location = (os.getenv("APPDATA") or "") + "/codexctl" elif os.name in ("posix", "darwin"): # Linux or MacOS folder_location = os.path.expanduser("~/.config/codexctl") else: @@ -65,11 +66,14 @@ def get_remarkable_versions(self) -> tuple[dict, dict, dict, str, str]: try: with open(file_location) as f: - contents = json.load(f) + contents = json.load(f) # pyright:ignore [reportAny] + if not isinstance(contents, dict): + raise ValueError() + except ValueError: raise SystemError( f"Version-ids.json @ {file_location} is corrupted! Please delete it and try again. Also, PLEASE open an issue on the repo showing the contents of the file." - ) + ) from None if ( int(datetime.now().timestamp()) - contents["last-updated"] @@ -78,15 +82,23 @@ def get_remarkable_versions(self) -> tuple[dict, dict, dict, str, str]: self.update_version_ids(file_location) with open(file_location) as f: - contents = json.load(f) + try: + contents = json.load(f) # pyright:ignore [reportAny] + if not isinstance(contents, dict): + raise ValueError() + + except ValueError: + raise SystemError( + f"Version-ids.json @ {file_location} is corrupted! Please delete it and try again. Also, PLEASE open an issue on the repo showing the contents of the file." + ) from None self.logger.debug(f"Version ids contents are {contents}") return ( - contents["remarkablepp"], - contents["remarkable2"], - contents["remarkable1"], - contents["external-provider-url"], + cast(dict[str, list[str]], contents["remarkablepp"]), + cast(dict[str, list[str]], contents["remarkable2"]), + cast(dict[str, list[str]], contents["remarkable1"]), + cast(str, contents["external-provider-url"]), ) def update_version_ids(self, location: str) -> None: @@ -98,22 +110,23 @@ def update_version_ids(self, location: str) -> None: Raises: SystemExit: If the file cannot be updated """ - with open(location, "w", newline="\n") as f: - try: + try: + with open(location, "w", newline="\n") as f: self.logger.debug("Downloading version-ids.json") - contents = requests.get( + contents = requests.get( # pyright:ignore [reportAny] "https://raw.githubusercontent.com/Jayy001/codexctl/main/data/version-ids.json" ).json() json.dump(contents, f, indent=4) - f.write("\n") - except requests.exceptions.Timeout: - raise SystemExit( - "Connection timed out while downloading version-ids.json! Do you have an internet connection?" - ) - except Exception as error: - raise SystemExit( - f"Unknown error while downloading version-ids.json! {error}" - ) + _ = f.write("\n") + except requests.exceptions.Timeout: + raise SystemExit( + "Connection timed out while downloading version-ids.json! Do you have an internet connection?" + ) from None + + except Exception as error: + raise SystemExit( + f"Unknown error while downloading version-ids.json! {error}" + ) from error def get_latest_version(self, hardware_type: HardwareType) -> str: """Gets the latest version available for the device @@ -132,7 +145,7 @@ def get_latest_version(self, hardware_type: HardwareType) -> str: case HardwareType.RMPP: versions = self.remarkablepp_versions - return self.__max_version(versions.keys()) + return self.__max_version(list(versions.keys())) def get_toltec_version(self, hardware_type: HardwareType) -> str: """Gets the latest version available toltec for the device @@ -232,58 +245,8 @@ def download_version( file_url, file_name, download_folder, version_checksum ) - def __generate_xml_data(self) -> str: - """Generates and returns XML data for the update request""" - params = { - "installsource": "scheduler", - "requestid": str(uuid.uuid4()), - "sessionid": str(uuid.uuid4()), - "machineid": "00".zfill(32), - "oem": "RM100-753-12345", - "appid": "98DA7DF2-4E3E-4744-9DE6-EC931886ABAB", - "bootid": str(uuid.uuid4()), - "current": "3.2.3.1595", - "group": "Prod", - "platform": "reMarkable2", - } - - return """ - - - - - -""".format(**params) - - def __parse_response(self, resp: str) -> tuple[str, str, str] | None: - """Parses the response from the update server and returns the file name, uri, and version if an update is available - - Args: - resp (str): Response from the server - - Returns: - tuple[str, str, str] | None: File name, uri, and version if an update is available, None otherwise - """ - xml_data = ET.fromstring(resp) - - if "noupdate" in resp or xml_data is None: - return None - - file_name = xml_data.find("app/updatecheck/manifest/packages/package").attrib[ - "name" - ] - file_uri = ( - f"{xml_data.find('app/updatecheck/urls/url').attrib['codebase']}{file_name}" - ) - file_version = xml_data.find("app/updatecheck/manifest").attrib["version"] - - self.logger.debug( - f"File version is {file_version}, file uri is {file_uri}, file name is {file_name}" - ) - return file_version, file_uri, file_name - def __download_version_file( - self, uri: str, name: str, download_folder: str, checksum: str + self, uri: str, name: str, download_folder: str | Path, checksum: str ) -> str | None: """Downloads the version file from the server and checks the checksum @@ -305,7 +268,7 @@ def __download_version_file( self.logger.debug(f"Downloading {name} from {uri} to {download_folder}") try: - file_length = int(file_length) + file_length = int(file_length or 0) if int(file_length) < 10000000: # 10MB, invalid version file self.logger.error( @@ -324,13 +287,14 @@ def __download_version_file( with open(filename, "wb") as out_file: dl = 0 - for data in response.iter_content(chunk_size=4096): + data: bytes + for data in response.iter_content(chunk_size=4096): # pyright:ignore [reportAny] dl += len(data) - out_file.write(data) + _ = out_file.write(data) if sys.stdout.isatty(): done = int(50 * dl / file_length) - sys.stdout.write("\r[%s%s]" % ("=" * done, " " * (50 - done))) - sys.stdout.flush() + _ = sys.stdout.write("\r[%s%s]" % ("=" * done, " " * (50 - done))) + _ = sys.stdout.flush() if sys.stdout.isatty(): print(end="\r\n") @@ -350,7 +314,7 @@ def __download_version_file( return filename @staticmethod - def __max_version(versions: list) -> str: + def __max_version(versions: list[str]) -> str: """Returns the highest avaliable version from a list with semantic versioning""" return sorted(versions, key=lambda v: tuple(map(int, v.split("."))))[-1]