Commit ae5413a
1 Parent(s): 949847c

refactor: implement proper middleware architecture (SPEC-21)
- Renames src/middleware → src/workflows (accurate naming)
- Creates proper src/middleware with ChatMiddleware implementations
- Implements RetryMiddleware (fixes HuggingFace 429 crashes)
- Implements TokenTrackingMiddleware (enables cost monitoring)
- Updates HuggingFaceChatClient to use the new middleware
- Updates tests and imports
Files changed:
- docs/specs/SPEC-21-MIDDLEWARE-ARCHITECTURE.md  +1 -1
- src/clients/huggingface.py  +9 -1
- src/middleware/__init__.py  +10 -1
- src/middleware/retry.py  +99 -0
- src/middleware/token_tracking.py  +73 -0
- src/orchestrators/hierarchical.py  +1 -1
- src/{middleware → workflows}/.gitkeep  +0 -0
- src/workflows/__init__.py  +1 -0
- src/{middleware → workflows}/sub_iteration.py  +0 -0
- tests/unit/middleware/__init__.py  +0 -0
- tests/unit/middleware/test_retry.py  +53 -0
- tests/unit/middleware/test_token_tracking.py  +41 -0
- tests/unit/test_hierarchical.py  +1 -1
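For orientation, the ChatMiddleware implementations added below follow an interceptor pattern: each middleware receives the chat context plus a `next` callable and decides if, when, and how often to invoke it. A minimal illustrative skeleton (the real `ChatMiddleware`/`ChatContext` come from `agent_framework._middleware`, as imported in src/middleware/retry.py; this `LoggingMiddleware` is hypothetical and not part of the commit):

    # Illustrative only: a hypothetical middleware showing the interceptor shape
    # that RetryMiddleware and TokenTrackingMiddleware below implement.
    from collections.abc import Awaitable, Callable

    from agent_framework._middleware import ChatContext, ChatMiddleware


    class LoggingMiddleware(ChatMiddleware):
        """Hypothetical example: run code before and after the wrapped chat call."""

        async def process(
            self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
        ) -> None:
            # pre: inspect or modify the outgoing request via `context`
            await next(context)  # call the next middleware, or the chat client itself
            # post: inspect `context.result` once the call has completed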
docs/specs/SPEC-21-MIDDLEWARE-ARCHITECTURE.md (CHANGED)

@@ -1,6 +1,6 @@
 # SPEC-21: Middleware Architecture Refactor
 
-**Status:**
+**Status:** COMPLETED
 **Priority:** P2 (Architectural hygiene + fixes HuggingFace retry bug)
 **Effort:** 2 hours
 **PR Scope:** Folder rename + new middleware implementations
src/clients/huggingface.py (CHANGED)

@@ -27,6 +27,8 @@ from agent_framework._types import FunctionCallContent, FunctionResultContent
 from agent_framework.observability import use_observability
 from huggingface_hub import InferenceClient
 
+from src.middleware.retry import RetryMiddleware
+from src.middleware.token_tracking import TokenTrackingMiddleware
 from src.utils.config import settings
 
 logger = structlog.get_logger()

@@ -51,7 +53,13 @@ class HuggingFaceChatClient(BaseChatClient):  # type: ignore[misc]
             api_key: HF_TOKEN (optional, defaults to env var).
             **kwargs: Additional arguments passed to BaseChatClient.
         """
-
+        # Create middleware instances
+        middleware: list[Any] = [
+            RetryMiddleware(max_attempts=3, min_wait=1.0, max_wait=10.0),
+            TokenTrackingMiddleware(),
+        ]
+
+        super().__init__(middleware=middleware, **kwargs)
         # FIX: Use 7B model to stay on HuggingFace native infrastructure (avoid Novita 500s)
         self.model_id = model_id or settings.huggingface_model or "Qwen/Qwen2.5-7B-Instruct"
         self.api_key = api_key or settings.hf_token
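With this change every request made through HuggingFaceChatClient passes through both middleware. A rough usage sketch, assuming a no-argument constructor (model_id and api_key are optional per the docstring); the `get_response` call and its argument shape are assumptions about the BaseChatClient interface, not shown in this diff:

    # Hypothetical usage sketch; get_response and its argument shape are assumptions.
    import asyncio

    from src.clients.huggingface import HuggingFaceChatClient


    async def main() -> None:
        client = HuggingFaceChatClient()  # middleware is attached inside __init__
        # 429s/5xx from HuggingFace are retried with backoff; token usage is logged.
        response = await client.get_response("Summarize SPEC-21 in one sentence.")
        print(response)


    asyncio.run(main())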
src/middleware/__init__.py (CHANGED)

@@ -1 +1,10 @@
-"""
+"""Microsoft Agent Framework middleware implementations.
+
+These are interceptor-pattern middleware that wrap chat client calls.
+They are NOT workflows - see src/workflows/ for orchestration patterns.
+"""
+
+from src.middleware.retry import RetryMiddleware
+from src.middleware.token_tracking import TokenTrackingMiddleware
+
+__all__ = ["RetryMiddleware", "TokenTrackingMiddleware"]
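Because both classes are re-exported and listed in __all__, callers can import them from the package root:

    # Package-level import path enabled by the new __init__.py.
    from src.middleware import RetryMiddleware, TokenTrackingMiddleware

    middleware = [RetryMiddleware(max_attempts=3), TokenTrackingMiddleware()]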
src/middleware/retry.py (ADDED, +99 lines)

"""Retry middleware for chat clients with exponential backoff."""

import asyncio
from collections.abc import Awaitable, Callable

import structlog
from agent_framework._middleware import ChatContext, ChatMiddleware

logger = structlog.get_logger()


class RetryMiddleware(ChatMiddleware):
    """Retries failed chat requests with exponential backoff.

    This middleware intercepts chat client calls and retries on transient
    errors (rate limits, timeouts, server errors).

    Attributes:
        max_attempts: Maximum number of attempts (default: 3)
        min_wait: Minimum wait between retries in seconds (default: 1.0)
        max_wait: Maximum wait between retries in seconds (default: 10.0)
        retryable_status_codes: HTTP status codes to retry (default: 429, 500, 502, 503, 504)
    """

    def __init__(
        self,
        max_attempts: int = 3,
        min_wait: float = 1.0,
        max_wait: float = 10.0,
        retryable_status_codes: tuple[int, ...] = (429, 500, 502, 503, 504),
    ) -> None:
        self.max_attempts = max_attempts
        self.min_wait = min_wait
        self.max_wait = max_wait
        self.retryable_status_codes = retryable_status_codes

    def _is_retryable(self, error: Exception) -> bool:
        """Check if error is retryable."""
        # Check for httpx status errors
        if hasattr(error, "response") and hasattr(error.response, "status_code"):
            return error.response.status_code in self.retryable_status_codes

        # Check for timeout errors
        error_name = type(error).__name__.lower()
        if "timeout" in error_name:
            return True

        # Check for connection errors
        if "connection" in error_name:
            return True

        return False

    def _calculate_wait(self, attempt: int) -> float:
        """Calculate wait time with exponential backoff."""
        wait = self.min_wait * (2**attempt)
        return float(min(wait, self.max_wait))

    async def process(
        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
    ) -> None:
        """Process the chat request with retry logic."""
        last_error: Exception | None = None

        for attempt in range(self.max_attempts):
            try:
                await next(context)
                return  # Success - exit retry loop

            except Exception as e:
                last_error = e

                if not self._is_retryable(e):
                    logger.warning(
                        "Non-retryable error",
                        error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise  # Don't retry non-retryable errors

                if attempt < self.max_attempts - 1:
                    wait_time = self._calculate_wait(attempt)
                    logger.info(
                        "Retrying after error",
                        attempt=attempt + 1,
                        max_attempts=self.max_attempts,
                        wait_seconds=wait_time,
                        error=str(e),
                    )
                    await asyncio.sleep(wait_time)

        # All retries exhausted
        logger.error(
            "All retry attempts failed",
            max_attempts=self.max_attempts,
            last_error=str(last_error),
        )
        if last_error:
            raise last_error
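With the defaults (min_wait=1.0, max_wait=10.0), _calculate_wait produces a doubling schedule capped at max_wait; a quick worked example:

    # Worked example of the backoff schedule (values follow from _calculate_wait above).
    from src.middleware.retry import RetryMiddleware

    mw = RetryMiddleware(max_attempts=5, min_wait=1.0, max_wait=10.0)
    print([mw._calculate_wait(attempt) for attempt in range(5)])
    # [1.0, 2.0, 4.0, 8.0, 10.0]  -> doubles each attempt, capped at max_wait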
src/middleware/token_tracking.py (ADDED, +73 lines)

"""Token tracking middleware for monitoring API usage."""

from collections.abc import Awaitable, Callable
from contextvars import ContextVar

import structlog
from agent_framework._middleware import ChatContext, ChatMiddleware

logger = structlog.get_logger()

# ContextVar for per-request token tracking
_request_tokens: ContextVar[dict[str, int]] = ContextVar("request_tokens")


class TokenTrackingMiddleware(ChatMiddleware):
    """Tracks token usage across chat requests.

    This middleware logs token usage after each chat completion
    and maintains running totals for the session.

    Usage metrics are logged via structlog for observability.
    """

    def __init__(self) -> None:
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.request_count = 0

    async def process(
        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
    ) -> None:
        """Process request and track token usage."""
        await next(context)

        # Extract usage from response if available
        if context.result is None:
            return

        usage = None

        # Try to get usage from response
        if hasattr(context.result, "usage"):
            usage = context.result.usage
        elif hasattr(context.result, "messages") and context.result.messages:
            # Check first message for usage metadata
            msg = context.result.messages[0]
            if hasattr(msg, "metadata") and msg.metadata:
                usage = msg.metadata.get("usage")

        if usage:
            input_tokens = usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0)
            output_tokens = usage.get("output_tokens", 0) or usage.get("completion_tokens", 0)

            self.total_input_tokens += input_tokens
            self.total_output_tokens += output_tokens
            self.request_count += 1

            logger.info(
                "Token usage",
                request_input=input_tokens,
                request_output=output_tokens,
                total_input=self.total_input_tokens,
                total_output=self.total_output_tokens,
                total_requests=self.request_count,
            )


def get_token_stats() -> dict[str, int]:
    """Get current request's token usage."""
    try:
        return _request_tokens.get().copy()
    except LookupError:
        return {"input": 0, "output": 0}
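The session totals live on the middleware instance, so a caller that keeps a reference to it can read them after a batch of requests; get_token_stats, by contrast, reads the _request_tokens ContextVar, which nothing in this file populates. A small hedged sketch of both (the client wiring mirrors src/clients/huggingface.py above):

    # Hypothetical: inspecting totals on a TokenTrackingMiddleware instance you
    # constructed yourself and passed to a chat client as its middleware.
    from src.middleware.token_tracking import TokenTrackingMiddleware, get_token_stats

    tracker = TokenTrackingMiddleware()
    # ... construct a chat client with middleware=[tracker], make some requests, then:
    print(tracker.total_input_tokens, tracker.total_output_tokens, tracker.request_count)

    # get_token_stats() only returns non-zero values if something sets the
    # _request_tokens ContextVar; as written it falls back to the zero dict.
    print(get_token_stats())  # {"input": 0, "output": 0}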
src/orchestrators/hierarchical.py (CHANGED)

@@ -19,11 +19,11 @@ import structlog
 from src.agents.judge_agent_llm import LLMSubIterationJudge
 from src.agents.magentic_agents import create_search_agent
 from src.config.domain import ResearchDomain
-from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam
 from src.orchestrators.base import OrchestratorProtocol
 from src.state import init_magentic_state
 from src.utils.models import AgentEvent, OrchestratorConfig
 from src.utils.service_loader import get_embedding_service_if_available
+from src.workflows.sub_iteration import SubIterationMiddleware, SubIterationTeam
 
 logger = structlog.get_logger()
 
src/{middleware → workflows}/.gitkeep (RENAMED)
File without changes.

src/workflows/__init__.py (ADDED, +1 line)

"""Middleware components for orchestration."""

src/{middleware → workflows}/sub_iteration.py (RENAMED)
File without changes.

tests/unit/middleware/__init__.py (ADDED)
Empty file.
tests/unit/middleware/test_retry.py (ADDED, +53 lines)

from unittest.mock import AsyncMock, MagicMock

import pytest

from src.middleware.retry import RetryMiddleware


@pytest.mark.asyncio
async def test_retry_middleware_succeeds_first_try():
    """RetryMiddleware should pass through on success."""
    middleware = RetryMiddleware(max_attempts=3)
    context = MagicMock()
    next_fn = AsyncMock()

    await middleware.process(context, next_fn)

    next_fn.assert_called_once_with(context)


@pytest.mark.asyncio
async def test_retry_middleware_retries_on_429():
    """RetryMiddleware should retry on 429 rate limit."""
    middleware = RetryMiddleware(max_attempts=3, min_wait=0.01)
    context = MagicMock()

    # First two calls fail with 429, third succeeds
    call_count = 0

    async def mock_next(ctx):
        nonlocal call_count
        call_count += 1
        if call_count < 3:
            error = Exception("Rate limited")
            error.response = MagicMock(status_code=429)
            raise error

    await middleware.process(context, mock_next)
    assert call_count == 3


@pytest.mark.asyncio
async def test_retry_middleware_raises_after_max_attempts():
    """RetryMiddleware should raise after max attempts exhausted."""
    middleware = RetryMiddleware(max_attempts=2, min_wait=0.01)
    context = MagicMock()

    async def always_fails(ctx):
        error = Exception("Always fails")
        error.response = MagicMock(status_code=500)
        raise error

    with pytest.raises(Exception, match="Always fails"):
        await middleware.process(context, always_fails)
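The committed tests cover first-try success, retry-then-succeed, and exhaustion. A complementary case, sketched here as a hypothetical addition that is not part of this commit, is that non-retryable status codes propagate immediately without retrying:

    # Hypothetical additional test: non-retryable errors (e.g. 400) should be
    # raised on the first attempt without any retries.
    from unittest.mock import MagicMock

    import pytest

    from src.middleware.retry import RetryMiddleware


    @pytest.mark.asyncio
    async def test_retry_middleware_does_not_retry_client_errors():
        middleware = RetryMiddleware(max_attempts=3, min_wait=0.01)
        context = MagicMock()
        call_count = 0

        async def fails_with_400(ctx):
            nonlocal call_count
            call_count += 1
            error = Exception("Bad request")
            error.response = MagicMock(status_code=400)
            raise error

        with pytest.raises(Exception, match="Bad request"):
            await middleware.process(context, fails_with_400)

        assert call_count == 1  # no retries for a 400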
tests/unit/middleware/test_token_tracking.py (ADDED, +41 lines)

from unittest.mock import AsyncMock, MagicMock

import pytest

from src.middleware.token_tracking import TokenTrackingMiddleware


@pytest.mark.asyncio
async def test_token_tracking_middleware_counts_tokens():
    """TokenTrackingMiddleware should count tokens from response."""
    middleware = TokenTrackingMiddleware()
    context = MagicMock()

    # Mock response with usage
    context.result.usage = {"input_tokens": 10, "output_tokens": 20}

    next_fn = AsyncMock()

    await middleware.process(context, next_fn)

    assert middleware.total_input_tokens == 10
    assert middleware.total_output_tokens == 20
    assert middleware.request_count == 1


@pytest.mark.asyncio
async def test_token_tracking_middleware_handles_no_usage():
    """TokenTrackingMiddleware should handle response without usage gracefully."""
    middleware = TokenTrackingMiddleware()
    context = MagicMock()
    context.result = MagicMock()
    del context.result.usage  # Ensure usage attr doesn't exist
    context.result.messages = []  # Ensure no messages

    next_fn = AsyncMock()

    await middleware.process(context, next_fn)

    assert middleware.total_input_tokens == 0
    assert middleware.total_output_tokens == 0
    assert middleware.request_count == 0
tests/unit/test_hierarchical.py (CHANGED)

@@ -4,8 +4,8 @@ from unittest.mock import AsyncMock
 
 import pytest
 
-from src.middleware.sub_iteration import SubIterationMiddleware
 from src.utils.models import AssessmentDetails, JudgeAssessment
+from src.workflows.sub_iteration import SubIterationMiddleware
 
 pytestmark = pytest.mark.unit
 