VibecoderMcSwaggins committed on
Commit ae5413a · 1 Parent(s): 949847c

refactor: implement proper middleware architecture (SPEC-21)


- Renames src/middleware → src/workflows (accurate naming)
- Creates proper src/middleware with ChatMiddleware implementations
- Implements RetryMiddleware (fixes HuggingFace 429 crashes)
- Implements TokenTrackingMiddleware (enables cost monitoring)
- Updates HuggingFaceChatClient to use new middleware
- Updates tests and imports
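
Roughly how the two new middleware compose around a chat call (a minimal sketch using mocked context/next objects rather than the framework's real dispatch; it assumes the list passed to `HuggingFaceChatClient` is applied outermost-first, so retry wraps token tracking):

```python
# Sketch only: composing RetryMiddleware and TokenTrackingMiddleware by hand.
# In the app, agent_framework builds this chain when HuggingFaceChatClient
# dispatches a request; the mocks here stand in for its context/next objects.
import asyncio
from unittest.mock import AsyncMock, MagicMock

from src.middleware.retry import RetryMiddleware
from src.middleware.token_tracking import TokenTrackingMiddleware


async def main() -> None:
    retry = RetryMiddleware(max_attempts=3, min_wait=0.01)
    tracking = TokenTrackingMiddleware()

    context = MagicMock()
    context.result.usage = {"input_tokens": 12, "output_tokens": 34}
    call_model = AsyncMock()  # stands in for the actual HuggingFace call

    # Outermost first: retry wraps token tracking, which wraps the model call.
    await retry.process(context, lambda ctx: tracking.process(ctx, call_model))

    print(tracking.total_input_tokens, tracking.total_output_tokens)  # 12 34


asyncio.run(main())
```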

docs/specs/SPEC-21-MIDDLEWARE-ARCHITECTURE.md CHANGED
@@ -1,6 +1,6 @@
 # SPEC-21: Middleware Architecture Refactor
 
-**Status:** READY FOR IMPLEMENTATION
+**Status:** COMPLETED
 **Priority:** P2 (Architectural hygiene + fixes HuggingFace retry bug)
 **Effort:** 2 hours
 **PR Scope:** Folder rename + new middleware implementations
src/clients/huggingface.py CHANGED
@@ -27,6 +27,8 @@ from agent_framework._types import FunctionCallContent, FunctionResultContent
 from agent_framework.observability import use_observability
 from huggingface_hub import InferenceClient
 
+from src.middleware.retry import RetryMiddleware
+from src.middleware.token_tracking import TokenTrackingMiddleware
 from src.utils.config import settings
 
 logger = structlog.get_logger()
@@ -51,7 +53,13 @@ class HuggingFaceChatClient(BaseChatClient): # type: ignore[misc]
             api_key: HF_TOKEN (optional, defaults to env var).
             **kwargs: Additional arguments passed to BaseChatClient.
         """
-        super().__init__(**kwargs)
+        # Create middleware instances
+        middleware: list[Any] = [
+            RetryMiddleware(max_attempts=3, min_wait=1.0, max_wait=10.0),
+            TokenTrackingMiddleware(),
+        ]
+
+        super().__init__(middleware=middleware, **kwargs)
         # FIX: Use 7B model to stay on HuggingFace native infrastructure (avoid Novita 500s)
         self.model_id = model_id or settings.huggingface_model or "Qwen/Qwen2.5-7B-Instruct"
         self.api_key = api_key or settings.hf_token
src/middleware/__init__.py CHANGED
@@ -1 +1,10 @@
-"""Middleware components for orchestration."""
+"""Microsoft Agent Framework middleware implementations.
+
+These are interceptor-pattern middleware that wrap chat client calls.
+They are NOT workflows - see src/workflows/ for orchestration patterns.
+"""
+
+from src.middleware.retry import RetryMiddleware
+from src.middleware.token_tracking import TokenTrackingMiddleware
+
+__all__ = ["RetryMiddleware", "TokenTrackingMiddleware"]
src/middleware/retry.py ADDED
@@ -0,0 +1,99 @@
+"""Retry middleware for chat clients with exponential backoff."""
+
+import asyncio
+from collections.abc import Awaitable, Callable
+
+import structlog
+from agent_framework._middleware import ChatContext, ChatMiddleware
+
+logger = structlog.get_logger()
+
+
+class RetryMiddleware(ChatMiddleware):
+    """Retries failed chat requests with exponential backoff.
+
+    This middleware intercepts chat client calls and retries on transient
+    errors (rate limits, timeouts, server errors).
+
+    Attributes:
+        max_attempts: Maximum number of attempts (default: 3)
+        min_wait: Minimum wait between retries in seconds (default: 1.0)
+        max_wait: Maximum wait between retries in seconds (default: 10.0)
+        retryable_status_codes: HTTP status codes to retry (default: 429, 500, 502, 503, 504)
+    """
+
+    def __init__(
+        self,
+        max_attempts: int = 3,
+        min_wait: float = 1.0,
+        max_wait: float = 10.0,
+        retryable_status_codes: tuple[int, ...] = (429, 500, 502, 503, 504),
+    ) -> None:
+        self.max_attempts = max_attempts
+        self.min_wait = min_wait
+        self.max_wait = max_wait
+        self.retryable_status_codes = retryable_status_codes
+
+    def _is_retryable(self, error: Exception) -> bool:
+        """Check if error is retryable."""
+        # Check for httpx status errors
+        if hasattr(error, "response") and hasattr(error.response, "status_code"):
+            return error.response.status_code in self.retryable_status_codes
+
+        # Check for timeout errors
+        error_name = type(error).__name__.lower()
+        if "timeout" in error_name:
+            return True
+
+        # Check for connection errors
+        if "connection" in error_name:
+            return True
+
+        return False
+
+    def _calculate_wait(self, attempt: int) -> float:
+        """Calculate wait time with exponential backoff."""
+        wait = self.min_wait * (2**attempt)
+        return float(min(wait, self.max_wait))
+
+    async def process(
+        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
+    ) -> None:
+        """Process the chat request with retry logic."""
+        last_error: Exception | None = None
+
+        for attempt in range(self.max_attempts):
+            try:
+                await next(context)
+                return  # Success - exit retry loop
+
+            except Exception as e:
+                last_error = e
+
+                if not self._is_retryable(e):
+                    logger.warning(
+                        "Non-retryable error",
+                        error=str(e),
+                        error_type=type(e).__name__,
+                    )
+                    raise  # Don't retry non-retryable errors
+
+                if attempt < self.max_attempts - 1:
+                    wait_time = self._calculate_wait(attempt)
+                    logger.info(
+                        "Retrying after error",
+                        attempt=attempt + 1,
+                        max_attempts=self.max_attempts,
+                        wait_seconds=wait_time,
+                        error=str(e),
+                    )
+                    await asyncio.sleep(wait_time)
+
+        # All retries exhausted
+        logger.error(
+            "All retry attempts failed",
+            max_attempts=self.max_attempts,
+            last_error=str(last_error),
+        )
+        if last_error:
+            raise last_error
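
For reference, the backoff schedule the defaults produce (a quick sketch that just exercises the private helper above):

```python
# Sketch: wait times RetryMiddleware._calculate_wait yields with the defaults
# (min_wait=1.0, max_wait=10.0). Doubles per attempt, capped at max_wait.
from src.middleware.retry import RetryMiddleware

rm = RetryMiddleware()
print([rm._calculate_wait(attempt) for attempt in range(5)])
# [1.0, 2.0, 4.0, 8.0, 10.0]
# With max_attempts=3 only the first two waits (1.0s, 2.0s) are ever slept,
# since no sleep follows the final attempt.
```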
src/middleware/token_tracking.py ADDED
@@ -0,0 +1,73 @@
+"""Token tracking middleware for monitoring API usage."""
+
+from collections.abc import Awaitable, Callable
+from contextvars import ContextVar
+
+import structlog
+from agent_framework._middleware import ChatContext, ChatMiddleware
+
+logger = structlog.get_logger()
+
+# ContextVar for per-request token tracking
+_request_tokens: ContextVar[dict[str, int]] = ContextVar("request_tokens")
+
+
+class TokenTrackingMiddleware(ChatMiddleware):
+    """Tracks token usage across chat requests.
+
+    This middleware logs token usage after each chat completion
+    and maintains running totals for the session.
+
+    Usage metrics are logged via structlog for observability.
+    """
+
+    def __init__(self) -> None:
+        self.total_input_tokens = 0
+        self.total_output_tokens = 0
+        self.request_count = 0
+
+    async def process(
+        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
+    ) -> None:
+        """Process request and track token usage."""
+        await next(context)
+
+        # Extract usage from response if available
+        if context.result is None:
+            return
+
+        usage = None
+
+        # Try to get usage from response
+        if hasattr(context.result, "usage"):
+            usage = context.result.usage
+        elif hasattr(context.result, "messages") and context.result.messages:
+            # Check first message for usage metadata
+            msg = context.result.messages[0]
+            if hasattr(msg, "metadata") and msg.metadata:
+                usage = msg.metadata.get("usage")
+
+        if usage:
+            input_tokens = usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0)
+            output_tokens = usage.get("output_tokens", 0) or usage.get("completion_tokens", 0)
+
+            self.total_input_tokens += input_tokens
+            self.total_output_tokens += output_tokens
+            self.request_count += 1
+
+            logger.info(
+                "Token usage",
+                request_input=input_tokens,
+                request_output=output_tokens,
+                total_input=self.total_input_tokens,
+                total_output=self.total_output_tokens,
+                total_requests=self.request_count,
+            )
+
+
+def get_token_stats() -> dict[str, int]:
+    """Get current request's token usage."""
+    try:
+        return _request_tokens.get().copy()
+    except LookupError:
+        return {"input": 0, "output": 0}
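
A minimal sketch of reading the running totals off a `TokenTrackingMiddleware` instance (mocked contexts; it also shows the `prompt_tokens`/`completion_tokens` fallback keys being picked up):

```python
# Sketch only: accumulating totals on a TokenTrackingMiddleware instance with
# mocked contexts. In the app the instance lives inside HuggingFaceChatClient.
import asyncio
from unittest.mock import AsyncMock, MagicMock

from src.middleware.token_tracking import TokenTrackingMiddleware


async def main() -> None:
    tracking = TokenTrackingMiddleware()
    for prompt, completion in [(100, 40), (250, 90)]:
        context = MagicMock()
        context.result.usage = {"prompt_tokens": prompt, "completion_tokens": completion}
        await tracking.process(context, AsyncMock())

    print(tracking.request_count)        # 2
    print(tracking.total_input_tokens)   # 350
    print(tracking.total_output_tokens)  # 130


asyncio.run(main())
```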
src/orchestrators/hierarchical.py CHANGED
@@ -19,11 +19,11 @@ import structlog
 from src.agents.judge_agent_llm import LLMSubIterationJudge
 from src.agents.magentic_agents import create_search_agent
 from src.config.domain import ResearchDomain
-from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam
 from src.orchestrators.base import OrchestratorProtocol
 from src.state import init_magentic_state
 from src.utils.models import AgentEvent, OrchestratorConfig
 from src.utils.service_loader import get_embedding_service_if_available
+from src.workflows.sub_iteration import SubIterationMiddleware, SubIterationTeam
 
 logger = structlog.get_logger()
 
src/{middleware → workflows}/.gitkeep RENAMED
File without changes
src/workflows/__init__.py ADDED
@@ -0,0 +1 @@
+"""Middleware components for orchestration."""
src/{middleware → workflows}/sub_iteration.py RENAMED
File without changes
tests/unit/middleware/__init__.py ADDED
File without changes
tests/unit/middleware/test_retry.py ADDED
@@ -0,0 +1,53 @@
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from src.middleware.retry import RetryMiddleware
+
+
+@pytest.mark.asyncio
+async def test_retry_middleware_succeeds_first_try():
+    """RetryMiddleware should pass through on success."""
+    middleware = RetryMiddleware(max_attempts=3)
+    context = MagicMock()
+    next_fn = AsyncMock()
+
+    await middleware.process(context, next_fn)
+
+    next_fn.assert_called_once_with(context)
+
+
+@pytest.mark.asyncio
+async def test_retry_middleware_retries_on_429():
+    """RetryMiddleware should retry on 429 rate limit."""
+    middleware = RetryMiddleware(max_attempts=3, min_wait=0.01)
+    context = MagicMock()
+
+    # First two calls fail with 429, third succeeds
+    call_count = 0
+
+    async def mock_next(ctx):
+        nonlocal call_count
+        call_count += 1
+        if call_count < 3:
+            error = Exception("Rate limited")
+            error.response = MagicMock(status_code=429)
+            raise error
+
+    await middleware.process(context, mock_next)
+    assert call_count == 3
+
+
+@pytest.mark.asyncio
+async def test_retry_middleware_raises_after_max_attempts():
+    """RetryMiddleware should raise after max attempts exhausted."""
+    middleware = RetryMiddleware(max_attempts=2, min_wait=0.01)
+    context = MagicMock()
+
+    async def always_fails(ctx):
+        error = Exception("Always fails")
+        error.response = MagicMock(status_code=500)
+        raise error
+
+    with pytest.raises(Exception, match="Always fails"):
+        await middleware.process(context, always_fails)
tests/unit/middleware/test_token_tracking.py ADDED
@@ -0,0 +1,41 @@
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from src.middleware.token_tracking import TokenTrackingMiddleware
+
+
+@pytest.mark.asyncio
+async def test_token_tracking_middleware_counts_tokens():
+    """TokenTrackingMiddleware should count tokens from response."""
+    middleware = TokenTrackingMiddleware()
+    context = MagicMock()
+
+    # Mock response with usage
+    context.result.usage = {"input_tokens": 10, "output_tokens": 20}
+
+    next_fn = AsyncMock()
+
+    await middleware.process(context, next_fn)
+
+    assert middleware.total_input_tokens == 10
+    assert middleware.total_output_tokens == 20
+    assert middleware.request_count == 1
+
+
+@pytest.mark.asyncio
+async def test_token_tracking_middleware_handles_no_usage():
+    """TokenTrackingMiddleware should handle response without usage gracefully."""
+    middleware = TokenTrackingMiddleware()
+    context = MagicMock()
+    context.result = MagicMock()
+    del context.result.usage  # Ensure usage attr doesn't exist
+    context.result.messages = []  # Ensure no messages
+
+    next_fn = AsyncMock()
+
+    await middleware.process(context, next_fn)
+
+    assert middleware.total_input_tokens == 0
+    assert middleware.total_output_tokens == 0
+    assert middleware.request_count == 0
tests/unit/test_hierarchical.py CHANGED
@@ -4,8 +4,8 @@ from unittest.mock import AsyncMock
 
 import pytest
 
-from src.middleware.sub_iteration import SubIterationMiddleware
 from src.utils.models import AssessmentDetails, JudgeAssessment
+from src.workflows.sub_iteration import SubIterationMiddleware
 
 pytestmark = pytest.mark.unit
 