Skip to content

Commit 5eacedf

Browse files
lint and test fixes
1 parent bac0fd2 commit 5eacedf

File tree

6 files changed

+280
-124
lines changed

6 files changed

+280
-124
lines changed

src/codeflare_sdk/common/utils/generate_cert.py

Lines changed: 63 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def generate_tls_cert(cluster_name, namespace, days=30, force_regenerate=False):
170170
- AuthorityKeyIdentifier
171171
172172
Files are created with restricted permissions (0600) for security.
173-
173+
174174
Certificates are stored in a user-private directory:
175175
- Default: ~/.local/share/codeflare/tls/{cluster_name}-{namespace}/
176176
- Override via CODEFLARE_TLS_DIR environment variable
@@ -194,27 +194,27 @@ def generate_tls_cert(cluster_name, namespace, days=30, force_regenerate=False):
194194
Raises:
195195
Exception:
196196
If an error occurs while retrieving the CA secret.
197-
197+
198198
Example:
199199
# Normal generation
200200
generate_tls_cert("my-cluster", "default")
201-
201+
202202
# Force regeneration if CA was rotated
203203
generate_tls_cert("my-cluster", "default", force_regenerate=True)
204204
"""
205205
tls_base_dir = _get_tls_base_dir()
206206
tls_dir = tls_base_dir / f"{cluster_name}-{namespace}"
207-
207+
208208
# Check if certificates already exist and skip if not forcing regeneration
209209
if not force_regenerate and tls_dir.exists():
210210
ca_crt = tls_dir / "ca.crt"
211211
tls_crt = tls_dir / "tls.crt"
212212
tls_key = tls_dir / "tls.key"
213-
213+
214214
if ca_crt.exists() and tls_crt.exists() and tls_key.exists():
215215
# Certificates already exist, no need to regenerate
216216
return
217-
217+
218218
# Create directory with secure permissions (including parent directories)
219219
tls_dir.mkdir(mode=0o700, parents=True, exist_ok=True)
220220
tls_dir = str(tls_dir)
@@ -254,7 +254,6 @@ def generate_tls_cert(cluster_name, namespace, days=30, force_regenerate=False):
254254
f.write(base64.b64decode(ca_cert).decode("utf-8"))
255255
os.chmod(ca_crt_path, stat.S_IRUSR | stat.S_IWUSR) # Set permissions to 0600
256256

257-
258257
# Generate tls.key and signed tls.cert locally for ray client
259258
# Similar to running these commands:
260259
# openssl req -nodes -newkey rsa:3072 -keyout ${TLSDIR}/tls.key -out ${TLSDIR}/tls.csr -subj '/CN=local'
@@ -325,21 +324,25 @@ def generate_tls_cert(cluster_name, namespace, days=30, force_regenerate=False):
325324
critical=True,
326325
)
327326
.add_extension(
328-
x509.ExtendedKeyUsage([
329-
ExtendedKeyUsageOID.SERVER_AUTH,
330-
ExtendedKeyUsageOID.CLIENT_AUTH, # For mTLS support
331-
]),
327+
x509.ExtendedKeyUsage(
328+
[
329+
ExtendedKeyUsageOID.SERVER_AUTH,
330+
ExtendedKeyUsageOID.CLIENT_AUTH, # For mTLS support
331+
]
332+
),
332333
critical=True,
333334
)
334335
.add_extension(
335-
x509.SubjectAlternativeName([
336-
x509.DNSName("localhost"),
337-
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
338-
x509.IPAddress(ipaddress.IPv6Address("::1")),
339-
x509.DNSName(head_svc_name),
340-
x509.DNSName(service_dns),
341-
x509.DNSName(service_dns_cluster_local),
342-
]),
336+
x509.SubjectAlternativeName(
337+
[
338+
x509.DNSName("localhost"),
339+
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
340+
x509.IPAddress(ipaddress.IPv6Address("::1")),
341+
x509.DNSName(head_svc_name),
342+
x509.DNSName(service_dns),
343+
x509.DNSName(service_dns_cluster_local),
344+
]
345+
),
343346
critical=False,
344347
)
345348
.add_extension(
@@ -362,7 +365,7 @@ def generate_tls_cert(cluster_name, namespace, days=30, force_regenerate=False):
362365
with open(tls_crt_path, "w") as f:
363366
f.write(tls_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8"))
364367
os.chmod(tls_crt_path, stat.S_IRUSR | stat.S_IWUSR) # Set permissions to 0600
365-
368+
366369
del ca_key, ca_private_key
367370
try:
368371
del secret
@@ -467,27 +470,31 @@ def list_tls_certificates():
467470
created = datetime.datetime.fromtimestamp(stat_info.st_ctime)
468471

469472
# Calculate total size
470-
total_size = sum(f.stat().st_size for f in cert_dir.rglob('*') if f.is_file())
473+
total_size = sum(
474+
f.stat().st_size for f in cert_dir.rglob("*") if f.is_file()
475+
)
471476

472477
# Try to read certificate expiry
473478
cert_expiry = None
474479
tls_cert_path = cert_dir / "tls.crt"
475480
if tls_cert_path.exists():
476481
try:
477-
with open(tls_cert_path, 'rb') as f:
482+
with open(tls_cert_path, "rb") as f:
478483
cert = x509.load_pem_x509_certificate(f.read())
479484
cert_expiry = cert.not_valid_after_utc
480485
except Exception:
481486
cert_expiry = None
482487

483-
certificates.append({
484-
'cluster_name': cluster_name,
485-
'namespace': namespace,
486-
'path': str(cert_dir),
487-
'created': created,
488-
'size': total_size,
489-
'cert_expiry': cert_expiry,
490-
})
488+
certificates.append(
489+
{
490+
"cluster_name": cluster_name,
491+
"namespace": namespace,
492+
"path": str(cert_dir),
493+
"created": created,
494+
"size": total_size,
495+
"cert_expiry": cert_expiry,
496+
}
497+
)
491498

492499
return certificates
493500

@@ -508,7 +515,7 @@ def cleanup_expired_certificates(dry_run=True):
508515
>>> # Check what would be deleted
509516
>>> expired = cleanup_expired_certificates(dry_run=True)
510517
>>> print(f"Found {len(expired)} expired certificates")
511-
>>>
518+
>>>
512519
>>> # Actually delete them
513520
>>> cleanup_expired_certificates(dry_run=False)
514521
"""
@@ -520,11 +527,11 @@ def cleanup_expired_certificates(dry_run=True):
520527
certificates = list_tls_certificates()
521528

522529
for cert_info in certificates:
523-
if cert_info['cert_expiry'] and cert_info['cert_expiry'] < now:
524-
expired_certs.append(cert_info['path'])
530+
if cert_info["cert_expiry"] and cert_info["cert_expiry"] < now:
531+
expired_certs.append(cert_info["path"])
525532

526533
if not dry_run:
527-
cert_dir = Path(cert_info['path'])
534+
cert_dir = Path(cert_info["path"])
528535
if cert_dir.exists():
529536
shutil.rmtree(cert_dir)
530537

@@ -534,62 +541,62 @@ def cleanup_expired_certificates(dry_run=True):
534541
def cleanup_old_certificates(days=30, dry_run=True):
535542
"""
536543
Removes TLS certificates older than a specified number of days.
537-
544+
538545
Args:
539546
days (int):
540547
Remove certificates created more than this many days ago. Default is 30.
541548
dry_run (bool):
542549
If True (default), only lists old certificates without deleting them.
543550
Set to False to actually delete old certificates.
544-
551+
545552
Returns:
546553
list: List of certificate paths that were (or would be) removed.
547-
554+
548555
Example:
549556
>>> # Check certificates older than 90 days
550557
>>> old = cleanup_old_certificates(days=90, dry_run=True)
551558
>>> print(f"Found {len(old)} certificates older than 90 days")
552-
>>>
559+
>>>
553560
>>> # Delete certificates older than 30 days
554561
>>> cleanup_old_certificates(days=30, dry_run=False)
555562
"""
556563
import shutil
557-
564+
558565
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=days)
559566
old_certs = []
560-
567+
561568
certificates = list_tls_certificates()
562-
569+
563570
for cert_info in certificates:
564-
if cert_info['created'] < cutoff_date:
565-
old_certs.append(cert_info['path'])
566-
571+
if cert_info["created"] < cutoff_date:
572+
old_certs.append(cert_info["path"])
573+
567574
if not dry_run:
568-
cert_dir = Path(cert_info['path'])
575+
cert_dir = Path(cert_info["path"])
569576
if cert_dir.exists():
570577
shutil.rmtree(cert_dir)
571-
578+
572579
return old_certs
573580

574581

575582
def refresh_tls_cert(cluster_name, namespace, days=30):
576583
"""
577584
Refreshes TLS certificates by removing old ones and generating new ones.
578-
585+
579586
This is useful when the server CA secret has been rotated and existing
580587
client certificates are no longer valid.
581-
588+
582589
Args:
583590
cluster_name (str):
584591
The name of the Ray cluster.
585592
namespace (str):
586593
The Kubernetes namespace where the Ray cluster is located.
587594
days (int):
588595
The number of days for which the new TLS certificate will be valid. Default is 30.
589-
596+
590597
Returns:
591598
bool: True if certificates were successfully refreshed.
592-
599+
593600
Example:
594601
>>> # Server CA was rotated, refresh client certificates
595602
>>> refresh_tls_cert("my-cluster", "default")
@@ -598,12 +605,13 @@ def refresh_tls_cert(cluster_name, namespace, days=30):
598605
"""
599606
# Remove old certificates
600607
cleanup_tls_cert(cluster_name, namespace)
601-
608+
602609
# Generate new ones
603610
generate_tls_cert(cluster_name, namespace, days=days, force_regenerate=True)
604-
611+
605612
return True
606613

614+
607615
def _get_tls_base_dir():
608616
"""
609617
Get the base directory for TLS certificate storage.
@@ -617,14 +625,14 @@ def _get_tls_base_dir():
617625
Path: Base directory for TLS certificates
618626
"""
619627
# Check for explicit override
620-
tls_dir_env = os.environ.get('CODEFLARE_TLS_DIR')
628+
tls_dir_env = os.environ.get("CODEFLARE_TLS_DIR")
621629
if tls_dir_env:
622630
return Path(tls_dir_env)
623631

624632
# Use XDG Base Directory specification
625-
xdg_data_home = os.environ.get('XDG_DATA_HOME')
633+
xdg_data_home = os.environ.get("XDG_DATA_HOME")
626634
if xdg_data_home:
627635
return Path(xdg_data_home) / "codeflare" / "tls"
628636

629637
# Fallback to standard location
630-
return Path.home() / ".local" / "share" / "codeflare" / "tls"
638+
return Path.home() / ".local" / "share" / "codeflare" / "tls"

0 commit comments

Comments
 (0)