Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion iotlabsshcli/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
from iotlabsshcli.parser.open_linux_parser import main as _main


def open_a8_cli():
def open_a8_cli() -> None:
"""Entry point for the deprecated open-a8-cli command."""
deprecate_cmd(_main, "open-a8-cli", "iotlab-ssh")
37 changes: 29 additions & 8 deletions iotlabsshcli/open_linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@

import os.path
from collections import OrderedDict
from typing import Any

from iotlabsshcli.sshlib import OpenLinuxSsh


def _nodes_grouped(nodes):
def _nodes_grouped(nodes: list[str]) -> OrderedDict[str, list[str]]:
"""Group nodes per site from a list of nodes.
>>> _nodes_grouped([])
OrderedDict()
Expand Down Expand Up @@ -58,7 +59,9 @@ def _nodes_grouped(nodes):
_REMOTE_SHARED_DIR = "shared/.iotlabsshcli"


def flash(config_ssh, nodes, firmware, verbose=False):
def flash(
config_ssh: dict[str, Any], nodes: list[str], firmware: str, verbose: bool = False
) -> dict[str, Any]:
"""Flash the firmware of co-microcontroller"""
failed_hosts = []
# configure ssh and remote firmware names.
Expand All @@ -79,7 +82,7 @@ def flash(config_ssh, nodes, firmware, verbose=False):
return {"flash": result}


def reset(config_ssh, nodes, verbose=False):
def reset(config_ssh: dict[str, Any], nodes: list[str], verbose: bool = False) -> dict[str, Any]:
"""Reset co-microcontroller"""

# Configure ssh
Expand All @@ -89,7 +92,9 @@ def reset(config_ssh, nodes, verbose=False):
return {"reset": ssh.run(_RESET_CMD)}


def wait_for_boot(config_ssh, nodes, max_wait=120, verbose=False):
def wait_for_boot(
config_ssh: dict[str, Any], nodes: list[str], max_wait: int = 120, verbose: bool = False
) -> dict[str, Any]:
"""Wait for the open Linux nodes boot"""

# Configure ssh.
Expand All @@ -99,7 +104,13 @@ def wait_for_boot(config_ssh, nodes, max_wait=120, verbose=False):
return {"wait-for-boot": ssh.wait(max_wait)}


def run_cmd(config_ssh, nodes, cmd, run_on_frontend=False, verbose=False):
def run_cmd(
config_ssh: dict[str, Any],
nodes: list[str],
cmd: str,
run_on_frontend: bool = False,
verbose: bool = False,
) -> dict[str, Any]:
"""Run a command on the Linux nodes or SSH frontend servers"""

# Configure ssh.
Expand All @@ -108,7 +119,9 @@ def run_cmd(config_ssh, nodes, cmd, run_on_frontend=False, verbose=False):
return {"run-cmd": ssh.run(cmd, with_proxy=not run_on_frontend)}


def copy_file(config_ssh, nodes, file_path, verbose=False):
def copy_file(
config_ssh: dict[str, Any], nodes: list[str], file_path: str, verbose: bool = False
) -> dict[str, Any]:
"""Copy a file to SSH frontend servers"""

# Configure ssh.
Expand All @@ -120,7 +133,9 @@ def copy_file(config_ssh, nodes, file_path, verbose=False):
return {"copy-file": result}


def _get_failed_result(groups, result, run_on_frontend):
def _get_failed_result(
groups: OrderedDict[str, list[str]], result: dict[str, list[str]], run_on_frontend: bool
) -> list[str]:
"""Returns failed nodes or SSH frontend servers list.

We delete failed hosts for the next commands in the groups
Expand All @@ -139,7 +154,13 @@ def _get_failed_result(groups, result, run_on_frontend):
return failed


def run_script(config_ssh, nodes, script, run_on_frontend=False, verbose=False):
def run_script(
config_ssh: dict[str, Any],
nodes: list[str],
script: str,
run_on_frontend: bool = False,
verbose: bool = False,
) -> dict[str, Any]:
"""Run a script in background on Linux nodes or SSH frontend servers"""

# Configure ssh.
Expand Down
9 changes: 5 additions & 4 deletions iotlabsshcli/parser/open_linux_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import argparse
import sys
from typing import Any

from iotlabcli import auth, helpers, rest
from iotlabcli.helpers import deprecate_warn_cmd
Expand All @@ -31,7 +32,7 @@
import iotlabsshcli.open_linux


def parse_options():
def parse_options() -> argparse.ArgumentParser:
"""Parse command line option."""
parent_parser = argparse.ArgumentParser(add_help=False)
common.add_auth_arguments(parent_parser, False)
Expand All @@ -54,7 +55,7 @@ def parse_options():
class DeprecateHelpFormatter(argparse.HelpFormatter):
"""Add drepecated help formatter"""

def add_usage(self, usage, actions, groups, prefix=None):
def add_usage(self, usage, actions, groups, prefix=None) -> None: # type: ignore[override]
# self._prog = iotlab-ssh flash-m3 | reset-m3
old_cmd = self._prog.split()[-1]
new_cmd = old_cmd.split("-")[0]
Expand Down Expand Up @@ -142,7 +143,7 @@ def add_usage(self, usage, actions, groups, prefix=None):
return parser


def open_linux_parse_and_run(opts):
def open_linux_parse_and_run(opts: argparse.Namespace) -> dict[str, Any]:
"""Parse namespace 'opts' object."""
user, passwd = auth.get_user_credentials(opts.username, opts.password)
api = rest.Api(user, passwd)
Expand Down Expand Up @@ -194,7 +195,7 @@ def open_linux_parse_and_run(opts):
return res


def main(args=None):
def main(args: list[str] | None = None) -> None:
"""Open Linux SSH cli parser."""
args = args or sys.argv[1:] # required for easy testing.
parser = parse_options()
Expand Down
41 changes: 30 additions & 11 deletions iotlabsshcli/sshlib/open_linux_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
import asyncio
import os
import time
from typing import Any

import asyncssh


def _cleanup_result(result):
def _cleanup_result(result: dict[str, list[str]]) -> dict[str, list[str]]:
"""Remove empty list from result.

>>> _cleanup_result({ '0': [], '1': []})
Expand All @@ -49,7 +50,9 @@ def _cleanup_result(result):
return result


def _extend_result(result, new_result):
def _extend_result(
result: dict[str, list[str]], new_result: dict[str, list[str]]
) -> dict[str, list[str]]:
"""Extend result dictionnary values with new result
dictionnary values

Expand Down Expand Up @@ -100,7 +103,7 @@ def _extend_result(result, new_result):
return result


def _check_all_nodes_processed(result):
def _check_all_nodes_processed(result: dict[str, list[str]]) -> bool:
"""Verify all nodes are successful or failed.

>>> _check_all_nodes_processed({ 'saclay': [], 'grenoble': []})
Expand All @@ -126,12 +129,14 @@ def _check_all_nodes_processed(result):
class OpenLinuxSsh:
"""Implement SSH API using asyncssh."""

def __init__(self, config_ssh, groups, verbose=False):
def __init__(
self, config_ssh: dict[str, Any], groups: dict[str, list[str]], verbose: bool = False
) -> None:
self.config_ssh = config_ssh
self.groups = groups
self.verbose = verbose

def run(self, command, with_proxy=True, **kwargs):
def run(self, command: str, with_proxy: bool = True, **kwargs: Any) -> dict[str, list[str]]:
"""Run ssh command on nodes, optionally through a proxy."""
result = {"0": [], "1": []}
for site, hosts in self.groups.items():
Expand All @@ -143,7 +148,7 @@ def run(self, command, with_proxy=True, **kwargs):
result = _extend_result(result, result_cmd)
return _cleanup_result(result)

def scp(self, src, dst):
def scp(self, src: str, dst: str) -> dict[str, list[str]]:
"""Copy file to SSH frontend via SFTP."""
result = {"0": [], "1": []}
for site in self.groups:
Expand All @@ -154,7 +159,7 @@ def scp(self, src, dst):
result["1"].append(site)
return _cleanup_result(result)

def wait(self, max_wait):
def wait(self, max_wait: int) -> dict[str, list[str]]:
"""Wait for requested Linux nodes until they boot."""
result = {"0": [], "1": []}
start_time = time.time()
Expand All @@ -167,13 +172,20 @@ def wait(self, max_wait):
result = _extend_result(result, result_cmd)
return _cleanup_result(result)

def _connect_kwargs(self, timeout=10):
def _connect_kwargs(self, timeout: int = 10) -> dict[str, Any]:
kwargs = {"known_hosts": None, "connect_timeout": timeout}
if SSH_KEY:
kwargs["client_keys"] = [os.path.expanduser(SSH_KEY)]
return kwargs

async def _run_command(self, command, hosts, proxy_host=None, timeout=10, **kwargs):
async def _run_command(
self,
command: str,
hosts: list[str],
proxy_host: str | None = None,
timeout: int = 10,
**kwargs: Any,
) -> dict[str, list[str]]:
tasks = [
self._run_on_host(host, command, proxy_host=proxy_host, timeout=timeout, **kwargs)
for host in hosts
Expand All @@ -190,7 +202,14 @@ async def _run_command(self, command, hosts, proxy_host=None, timeout=10, **kwar
result["0"].append(host)
return result

async def _run_on_host(self, host, command, proxy_host=None, timeout=10, **kwargs):
async def _run_on_host(
self,
host: str,
command: str,
proxy_host: str | None = None,
timeout: int = 10,
**kwargs: Any,
) -> tuple[int, str]:
ck = self._connect_kwargs(timeout)
if proxy_host:
async with asyncssh.connect(
Expand All @@ -204,7 +223,7 @@ async def _run_on_host(self, host, command, proxy_host=None, timeout=10, **kwarg
result = await conn.run(command, **kwargs)
return result.exit_status, result.stdout or ""

async def _copy_file(self, site, src, dst):
async def _copy_file(self, site: str, src: str, dst: str) -> None:
ck = self._connect_kwargs()
async with asyncssh.connect(site, username=self.config_ssh["user"], **ck) as conn:
async with conn.start_sftp_client() as sftp:
Expand Down
13 changes: 7 additions & 6 deletions iotlabsshcli/tests/iotlabsshcli_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import sys
import unittest
from typing import Any
from unittest.mock import Mock, patch

from iotlabcli.helpers import json_dumps
Expand All @@ -34,14 +35,14 @@
class RequestRet: # pylint:disable=too-few-public-methods
"""Mock of Request return value"""

def __init__(self, status_code, content, headers=None):
def __init__(self, status_code: int, content: str, headers: Any = None) -> None:
self.status_code = status_code
self.content = content.encode("utf-8")
self.headers = headers
self.text = self.content.decode("utf-8")


def api_mock(ret=None):
def api_mock(ret: dict[str, Any] | None = None) -> Mock:
"""Return a mock of an api object
returned value for api methods will be 'ret' parameter or API_RET
"""
Expand All @@ -53,7 +54,7 @@ def api_mock(ret=None):
return api_class.return_value


def api_mock_stop():
def api_mock_stop() -> None:
"""Stop all patches started by api_mock.
Actually it stops everything but not a problem"""
patch.stopall()
Expand All @@ -62,7 +63,7 @@ def api_mock_stop():
class MainMock(unittest.TestCase):
"""Common mock needed for testing main function of parsers"""

def setUp(self):
def setUp(self) -> None:
self.api = api_mock()

patch("sys.stderr", sys.stdout).start()
Expand All @@ -75,11 +76,11 @@ def setUp(self):
"iotlabcli.auth.get_user_credentials", Mock(return_value=("username", "password"))
).start()

def get_exp(_, x, running_only=True):
def get_exp(_: Any, x: int | None, running_only: bool = True) -> int:
return x if x is not None else (123 if running_only else 234)

patch("iotlabcli.helpers.get_current_experiment", get_exp).start()

def tearDown(self):
def tearDown(self) -> None:
api_mock_stop()
patch.stopall()
Loading
Loading