Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/agentic_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

"""
Support for Entra Agentic Sessions.

When CLI runs inside an agent context (e.g., Copilot, Azure MCP), the orchestrator sets the
COPILOT_AGENT_SESSION_ID environment variable. CLI reads it and passes it to MSAL as both:
- A query parameter (`client_session`) so ESTS can identify the agentic session
- A claims challenge so ESTS embeds an agentic marker claim in the token (and MSAL bypasses
the access token cache to ensure a fresh, agent-tagged token is always fetched)

This enables downstream systems (RBAC, Defender, Purview) to enforce differentiated policies
for agent-driven vs. human-driven operations.
"""

import json
import os

from knack.log import get_logger

logger = get_logger(__name__)

COPILOT_AGENT_SESSION_ID = "COPILOT_AGENT_SESSION_ID"


def build_agentic_session_params():
"""Read COPILOT_AGENT_SESSION_ID and build the agentic claims challenge.

:returns: (session_id, claims_challenge) — both None when env var is not set.
"""
session_id = os.environ.get(COPILOT_AGENT_SESSION_ID) or None
if not session_id:
return None, None

logger.debug("Agentic session detected (COPILOT_AGENT_SESSION_ID is set)")

claims_challenge = json.dumps({
"access_token": {
"xms_cli_sid": {"values": [session_id]}
}
})
return session_id, claims_challenge


def merge_access_token_claims(existing_claims, new_claims):
"""Merge new claims into an existing claims_challenge JSON string.

:param existing_claims: Existing claims_challenge JSON string (or None).
:param new_claims: New claims_challenge JSON string to merge in. Must not be None or empty,
and must contain a non-empty ``access_token`` object.
:returns: Merged claims_challenge JSON string.
:raises ValueError: If ``new_claims`` is None, empty, or does not contain a non-empty
``access_token`` object.
"""
if not new_claims:
raise ValueError("new_claims must not be None or empty")
new_access_token = json.loads(new_claims).get("access_token")
if not new_access_token:
raise ValueError("new_claims must contain a non-empty access_token")

claims_dict = json.loads(existing_claims) if existing_claims else {}
claims_dict["access_token"] = claims_dict.get("access_token") or {}
claims_dict["access_token"].update(new_access_token)
return json.dumps(claims_dict)
21 changes: 21 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/msal_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,27 @@ def acquire_token(self, scopes, claims_challenge=None, **kwargs):
logger.debug("UserCredential.acquire_token: scopes=%r, claims_challenge=%r, kwargs=%r",
scopes, claims_challenge, kwargs)

# Apply agentic session parameters for user identity flows
from .agentic_session import build_agentic_session_params, merge_access_token_claims
agentic_session_id, agentic_claims = build_agentic_session_params()
if agentic_session_id:
# Both paths: client_session in data and params so eSTS can identify the agentic session
kwargs["data"] = kwargs.get("data") or {}
kwargs["data"]["client_session"] = agentic_session_id
kwargs["params"] = kwargs.get("params") or {}
kwargs["params"]["client_session"] = agentic_session_id
Comment thread
xuming-ms marked this conversation as resolved.

Comment thread
xuming-ms marked this conversation as resolved.
if getattr(self._msal_app, '_enable_broker', False):
# Broker path: claims_challenge flows to MSALRuntime cache key via set_decoded_claims.
# This causes MSAL to skip its local AT cache and forward claims to the broker,
# where requestedClaims becomes part of the C++ cache key.
claims_challenge = merge_access_token_claims(claims_challenge, agentic_claims)
# Non-broker path: client_session in data flows into ext_cache_key (SHA256 hash),
# which partitions the MSAL Python token cache. No claims_challenge needed.

from azure.cli.core.telemetry import set_agentic_session
set_agentic_session(True)

if claims_challenge:
logger.info('Acquiring new access token silently with claims challenge: %s', claims_challenge)
result = self._msal_app.acquire_token_silent_with_error(
Expand Down
215 changes: 215 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/tests/test_agentic_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import json
import os
import unittest
from unittest.mock import patch

from azure.cli.core.auth.agentic_session import (
COPILOT_AGENT_SESSION_ID,
build_agentic_session_params,
merge_access_token_claims,
)


class TestBuildAgenticSessionParams(unittest.TestCase):

def test_returns_none_when_env_not_set(self):
with patch.dict(os.environ, {}, clear=True):
session_id, claims = build_agentic_session_params()
self.assertIsNone(session_id)
self.assertIsNone(claims)

def test_returns_none_when_env_is_empty_string(self):
with patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: ""}):
session_id, claims = build_agentic_session_params()
self.assertIsNone(session_id)
self.assertIsNone(claims)

def test_returns_session_id_and_claims(self):
with patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: "sess-456"}):
session_id, claims = build_agentic_session_params()
self.assertEqual(session_id, "sess-456")
parsed = json.loads(claims)
self.assertEqual(parsed["access_token"]["xms_cli_sid"]["values"], ["sess-456"])

def _agentic_claims(session_id="s1"):
return json.dumps({"access_token": {"xms_cli_sid": {"values": [session_id]}}})


class TestMergeAccessTokenClaims(unittest.TestCase):

# --- Validation ---

def test_raises_when_new_claims_is_none(self):
with self.assertRaises(ValueError):
merge_access_token_claims(None, None)

def test_raises_when_new_access_token_is_null(self):
new = json.dumps({"access_token": None})
with self.assertRaises(ValueError):
merge_access_token_claims(None, new)

# --- Merging ---

def test_merges_into_none(self):
result = merge_access_token_claims(None, _agentic_claims("s1"))
claims = json.loads(result)
self.assertEqual(len(claims), 1)
self.assertEqual(len(claims["access_token"]), 1)
self.assertEqual(claims["access_token"]["xms_cli_sid"], {"values": ["s1"]})

def test_merges_into_existing(self):
existing = json.dumps({"access_token": {"nbf": {"essential": True, "value": "999"}}})
result = merge_access_token_claims(existing, _agentic_claims("s1"))
merged = json.loads(result)
self.assertEqual(len(merged), 1)
self.assertEqual(len(merged["access_token"]), 2)
self.assertEqual(merged["access_token"]["nbf"], {"essential": True, "value": "999"})
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["s1"]})

def test_preserves_non_access_token_keys(self):
existing = json.dumps({
"access_token": {"nbf": {"essential": True}},
"id_token": {"auth_time": {"essential": True}}
})
result = merge_access_token_claims(existing, _agentic_claims())
merged = json.loads(result)
self.assertEqual(len(merged), 2)
self.assertEqual(len(merged["access_token"]), 2)
self.assertEqual(merged["id_token"], {"auth_time": {"essential": True}})
self.assertEqual(merged["access_token"]["nbf"], {"essential": True})
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["s1"]})

def test_new_claims_overwrites_existing_key(self):
existing = json.dumps({"access_token": {"xms_cli_sid": {"values": ["old"]}}})
result = merge_access_token_claims(existing, _agentic_claims("new"))
merged = json.loads(result)
self.assertEqual(len(merged), 1)
self.assertEqual(len(merged["access_token"]), 1)
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["new"]})

def test_creates_access_token_when_missing_in_existing(self):
existing = json.dumps({"id_token": {"auth_time": {"essential": True}}})
result = merge_access_token_claims(existing, _agentic_claims())
merged = json.loads(result)
self.assertEqual(len(merged), 2)
self.assertEqual(len(merged["access_token"]), 1)
self.assertEqual(merged["id_token"], {"auth_time": {"essential": True}})
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["s1"]})

def test_handles_null_access_token_in_existing(self):
existing = json.dumps({"access_token": None})
result = merge_access_token_claims(existing, _agentic_claims())
merged = json.loads(result)
self.assertEqual(len(merged), 1)
self.assertEqual(len(merged["access_token"]), 1)
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["s1"]})


class TestUserCredentialAgenticSession(unittest.TestCase):
"""Verify that UserCredential.acquire_token merges agentic claims and passes
client_session param when COPILOT_AGENT_SESSION_ID is set."""

def _build_user_credential(self, enable_broker=False):
"""Build a UserCredential with mocked MSAL app."""
from unittest.mock import MagicMock, PropertyMock
from azure.cli.core.auth.msal_credentials import UserCredential

cred = object.__new__(UserCredential)

cred._msal_app = MagicMock()
cred._msal_app.client_id = "test-client-id"
cred._msal_app._enable_broker = enable_broker
type(cred._msal_app).authority = PropertyMock(return_value=MagicMock(
instance="login.microsoftonline.com",
tenant="test-tenant",
is_adfs=False,
))
cred._account = {
"home_account_id": "uid.utid",
"username": "user@test.com",
}
return cred

@patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: "agent-sess-1"})
def test_non_broker_passes_data_only(self):
"""Non-broker path: client_session in data for ext_cache_key, no claims_challenge."""
cred = self._build_user_credential(enable_broker=False)
cred._msal_app.acquire_token_silent_with_error.return_value = {
"access_token": "agent-tagged-token",
"token_type": "Bearer",
"expires_in": 3600,
}

result = cred.acquire_token(["https://management.azure.com/.default"])

self.assertEqual(result["access_token"], "agent-tagged-token")

call_kwargs = cred._msal_app.acquire_token_silent_with_error.call_args
self.assertIsNone(call_kwargs.kwargs.get("claims_challenge"))
self.assertEqual(call_kwargs.kwargs["data"], {"client_session": "agent-sess-1"})
self.assertEqual(call_kwargs.kwargs["params"], {"client_session": "agent-sess-1"})

@patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: "agent-sess-1"})
def test_broker_passes_claims_and_data(self):
"""Broker path: claims_challenge with xms_cli_sid AND client_session in data."""
cred = self._build_user_credential(enable_broker=True)
cred._msal_app.acquire_token_silent_with_error.return_value = {
"access_token": "agent-tagged-token",
"token_type": "Bearer",
"expires_in": 3600,
}

result = cred.acquire_token(["https://management.azure.com/.default"])

self.assertEqual(result["access_token"], "agent-tagged-token")

call_kwargs = cred._msal_app.acquire_token_silent_with_error.call_args
claims = json.loads(call_kwargs.kwargs["claims_challenge"])
self.assertEqual(claims["access_token"]["xms_cli_sid"]["values"], ["agent-sess-1"])
self.assertEqual(call_kwargs.kwargs["data"], {"client_session": "agent-sess-1"})
self.assertEqual(call_kwargs.kwargs["params"], {"client_session": "agent-sess-1"})

@patch.dict(os.environ, {}, clear=True)
def test_no_agentic_params_without_env(self):
"""When COPILOT_AGENT_SESSION_ID is not set, no agentic params are added."""
cred = self._build_user_credential(enable_broker=False)
cred._msal_app.acquire_token_silent_with_error.return_value = {
"access_token": "normal-token",
"token_type": "Bearer",
"expires_in": 3600,
}

result = cred.acquire_token(["https://management.azure.com/.default"])

self.assertEqual(result["access_token"], "normal-token")

call_kwargs = cred._msal_app.acquire_token_silent_with_error.call_args
self.assertIsNone(call_kwargs.kwargs.get("claims_challenge"))
self.assertNotIn("params", call_kwargs.kwargs)

@patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: "agent-sess-2"})
def test_broker_merges_with_existing_claims(self):
"""Broker path: agentic claims are merged with existing claims_challenge."""
cred = self._build_user_credential(enable_broker=True)
cred._msal_app.acquire_token_silent_with_error.return_value = {
"access_token": "token",
"token_type": "Bearer",
"expires_in": 3600,
}

existing_claims = json.dumps({"access_token": {"nbf": {"essential": True, "value": "999"}}})
cred.acquire_token(["scope"], claims_challenge=existing_claims)

call_kwargs = cred._msal_app.acquire_token_silent_with_error.call_args
claims = json.loads(call_kwargs.kwargs["claims_challenge"])
self.assertEqual(claims["access_token"]["nbf"], {"essential": True, "value": "999"})
self.assertEqual(claims["access_token"]["xms_cli_sid"]["values"], ["agent-sess-2"])


if __name__ == '__main__':
unittest.main()
7 changes: 7 additions & 0 deletions src/azure-cli-core/azure/cli/core/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, correlation_id=None, application=None):
self.enable_broker_on_windows = None
self.msal_telemetry = None
self.login_experience_v2 = None
self.agentic_session = False

def add_event(self, name, properties):
for key in self.instrumentation_key:
Expand Down Expand Up @@ -239,6 +240,7 @@ def _get_azure_cli_properties(self):
set_custom_properties(result, 'EnableBrokerOnWindows', str(self.enable_broker_on_windows))
set_custom_properties(result, 'MsalTelemetry', self.msal_telemetry)
set_custom_properties(result, 'LoginExperienceV2', str(self.login_experience_v2))
set_custom_properties(result, 'AgenticSession', str(self.agentic_session))

return result

Expand Down Expand Up @@ -497,6 +499,11 @@ def set_msal_telemetry(msal_telemetry):
@decorators.suppress_all_exceptions()
def set_login_experience_v2(login_experience_v2):
_session.login_experience_v2 = login_experience_v2


@decorators.suppress_all_exceptions()
def set_agentic_session(agentic_session):
_session.agentic_session = agentic_session
# endregion


Expand Down
4 changes: 2 additions & 2 deletions src/azure-cli-core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
'knack~=0.11.0',
'microsoft-security-utilities-secret-masker~=1.0.0b4',
'msal-extensions==1.3.1',
'msal[broker]==1.35.1; sys_platform == "win32"',
'msal==1.35.1; sys_platform != "win32"',
'msal[broker]==1.36.0; sys_platform == "win32"',
'msal==1.36.0; sys_platform != "win32"',
'packaging>=20.9',
'pkginfo>=1.5.0.1',
# psutil can't install on cygwin: https://github.com/Azure/azure-cli/issues/9399
Expand Down
Loading
Loading