# Source code for selenium.webdriver.common.api_request_context

# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""APIRequestContext for making HTTP requests with browser cookie synchronization."""

import json
import logging
import pathlib
import time
import urllib.parse
from email.utils import parsedate_to_datetime
from http.client import responses as http_status_phrases
from typing import TYPE_CHECKING, Any

import urllib3
from urllib3.util.retry import Retry

if TYPE_CHECKING:
    from selenium.webdriver.remote.webdriver import WebDriver

logger = logging.getLogger(__name__)


class APIRequestFailure(Exception):
    """Error raised for a non-2xx API response when fail_on_status_code is True.

    The exception message has the form ``"<status> <status_text>: <url>"``.

    Attributes:
        response: The APIResponse that triggered the failure.
    """

    def __init__(self, response: "APIResponse") -> None:
        # Build the human-readable summary first, then hand it to Exception.
        summary = f"{response.status} {response.status_text}: {response.url}"
        super().__init__(summary)
        self.response = response
class APIResponse:
    """HTTP response produced by an API request.

    Attributes:
        status: HTTP status code.
        status_text: HTTP status text.
        headers: Response headers as a dict.
        url: The request URL.
    """

    def __init__(self, status: int, status_text: str, headers: dict[str, str], url: str, body: bytes) -> None:
        # The raw payload stays private; callers read it via json()/text()/body().
        self._body = body
        self.url = url
        self.headers = headers
        self.status_text = status_text
        self.status = status

    @property
    def ok(self) -> bool:
        """True when the status code lies in the inclusive range 200-299."""
        return self.status >= 200 and self.status <= 299

    def json(self) -> Any:
        """Deserialize the response body as JSON.

        Returns:
            The parsed JSON object.
        """
        payload = self._body
        return json.loads(payload)

    def text(self) -> str:
        """Return the response body decoded as UTF-8.

        Returns:
            The response body as a string.
        """
        return self._body.decode(encoding="utf-8")

    def body(self) -> bytes:
        """Return the response body exactly as stored.

        Returns:
            The response body as bytes.
        """
        raw = self._body
        return raw

    def dispose(self) -> None:
        """Drop the stored body by replacing it with empty bytes."""
        self._body = bytes()
def _cookie_matches(cookie: dict, url: str, default_domain: str = "") -> bool:
    """Check if a browser cookie should be sent with a request to the given URL.

    Evaluates expiry, domain, path, and secure attribute matching per RFC 6265.

    Args:
        cookie: A cookie dict from driver.get_cookies().
        url: The target request URL.
        default_domain: Fallback domain for host-only cookies (no domain attribute).
            When a cookie has no domain, it only matches if the request hostname
            equals this value. If empty and cookie has no domain, the cookie is skipped.

    Returns:
        True if the cookie matches the URL.
    """
    # Expiry check — skip expired cookies
    expiry = cookie.get("expiry")
    if expiry is not None and expiry <= int(time.time()):
        return False

    parsed = urllib.parse.urlparse(url)
    hostname = parsed.hostname or ""  # urlparse already lower-cases the hostname
    path = parsed.path or "/"
    scheme = parsed.scheme or "http"

    # Domain matching (RFC 6265 section 5.1.3). Domains are case-insensitive,
    # so compare in lower case.
    cookie_domain = (cookie.get("domain") or "").lower()
    if not cookie_domain:
        # Host-only cookie — must match the origin host exactly
        if not default_domain or hostname != default_domain.lower():
            return False
    elif cookie_domain.startswith("."):
        # .example.com matches example.com and sub.example.com
        if not (hostname == cookie_domain[1:] or hostname.endswith(cookie_domain)):
            return False
    elif hostname != cookie_domain:
        return False

    # Path matching (RFC 6265 section 5.1.4): the cookie path must equal the
    # request path or be a prefix of it that ends on a "/" boundary.
    cookie_path = cookie.get("path") or "/"
    if cookie_path != "/" and path != cookie_path:
        # A cookie path that already ends with "/" is itself the boundary;
        # appending another "/" would wrongly reject "/api/users" for "/api/".
        prefix = cookie_path if cookie_path.endswith("/") else cookie_path + "/"
        if not path.startswith(prefix):
            return False

    # Secure cookies are only sent over https
    if cookie.get("secure", False) and scheme != "https":
        return False

    return True


def _parse_set_cookie(header_value: str) -> dict:
    """Parse a single Set-Cookie header value into a cookie dict.

    Uses manual parsing instead of http.cookies.SimpleCookie which is too strict
    for real-world Set-Cookie headers.

    Args:
        header_value: The Set-Cookie header string.

    Returns:
        A dict with cookie attributes suitable for driver.add_cookie(), or an
        empty dict when the header has no ``name=value`` pair.
    """
    parts = header_value.split(";")
    name, eq, value = parts[0].strip().partition("=")
    if not eq:
        return {}

    cookie: dict[str, Any] = {"name": name.strip(), "value": value.strip()}
    has_max_age = False

    for part in parts[1:]:
        part = part.strip()
        if not part:
            continue
        # Flag attributes (Secure, HttpOnly) have no "=" and get an empty value.
        attr_name, _, attr_value = part.partition("=")
        attr_name = attr_name.strip().lower()
        attr_value = attr_value.strip()

        if attr_name == "domain":
            cookie["domain"] = attr_value
        elif attr_name == "path":
            cookie["path"] = attr_value
        elif attr_name == "secure":
            cookie["secure"] = True
        elif attr_name == "httponly":
            cookie["httpOnly"] = True
        elif attr_name == "samesite":
            # SameSite values are case-insensitive on the wire, but WebDriver
            # only accepts the exact spellings "Strict", "Lax" and "None".
            # Unrecognised values are dropped so the cookie can still be added.
            same_site = attr_value.capitalize()
            if same_site in ("Strict", "Lax", "None"):
                cookie["sameSite"] = same_site
        elif attr_name == "max-age":
            try:
                max_age = int(attr_value)
            except ValueError:
                continue
            cookie["expiry"] = int(time.time()) + max_age
            has_max_age = True
        elif attr_name == "expires" and not has_max_age:
            # RFC 6265 §5.3: Max-Age takes precedence over Expires
            try:
                dt = parsedate_to_datetime(attr_value)
                cookie["expiry"] = int(dt.timestamp())
            except (ValueError, TypeError):
                pass

    return cookie


def _get_set_cookie_headers(resp: urllib3.BaseHTTPResponse) -> list[str]:
    """Extract all Set-Cookie header values from a urllib3 response.

    Args:
        resp: The urllib3 HTTP response.

    Returns:
        A list of Set-Cookie header strings.
    """
    # Prefer getlist() when the header container offers it, so that multiple
    # Set-Cookie headers are returned as separate entries.
    if hasattr(resp.headers, "getlist"):
        headers = resp.headers.getlist("Set-Cookie")
        if headers:
            return headers
    sc = resp.headers.get("Set-Cookie")
    return [sc] if sc else []


def _resolve_redirect_url(resp: urllib3.BaseHTTPResponse, original_url: str) -> str:
    """Return the final URL after any redirects.

    urllib3's retry history records each hop. When redirects occurred, the last
    entry's redirect_location resolved against its URL gives the final destination.
    When no redirects occurred, the original request URL is returned unchanged.

    Args:
        resp: The urllib3 HTTP response.
        original_url: The URL the request was originally sent to.

    Returns:
        The URL of the final destination.
    """
    history = resp.retries.history if resp.retries else ()
    if history:
        last = history[-1]
        if last.url and last.redirect_location:
            return urllib.parse.urljoin(last.url, last.redirect_location)
    return original_url


def _setdefault_header(headers: dict[str, str], name: str, value: str) -> None:
    """Set a header unless one with the same name (in any letter case) exists."""
    wanted = name.lower()
    if not any(existing.lower() == wanted for existing in headers):
        headers[name] = value


class _BaseRequestContext:
    """Base class with shared HTTP request logic for API request contexts.

    Subclasses supply cookie behaviour by overriding
    ``_get_cookies_for_request`` and ``_handle_response_cookies``.
    """

    def __init__(
        self,
        base_url: str = "",
        extra_headers: dict[str, str] | None = None,
        timeout: float = 30.0,
        max_redirects: int = 10,
        fail_on_status_code: bool = False,
    ) -> None:
        self._base_url = base_url
        self._extra_headers = extra_headers or {}
        self._timeout = timeout
        self._max_redirects = max_redirects
        self._fail_on_status_code = fail_on_status_code
        self._pool = urllib3.PoolManager()

    def get(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a GET request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, timeout,
                max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "GET", **kwargs)

    def post(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a POST request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, data, form, json_data,
                timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "POST", **kwargs)

    def put(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a PUT request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, data, form, json_data,
                timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "PUT", **kwargs)

    def patch(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a PATCH request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, data, form, json_data,
                timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "PATCH", **kwargs)

    def delete(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a DELETE request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, data, form, json_data,
                timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "DELETE", **kwargs)

    def head(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a HEAD request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, timeout,
                max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "HEAD", **kwargs)

    def fetch(self, url: str, method: str = "GET", **kwargs: Any) -> APIResponse:
        """Send an HTTP request with a custom method.

        Args:
            url: The request URL (absolute or relative to base_url).
            method: The HTTP method to use.
            **kwargs: Optional arguments: headers, params, data, form, json_data,
                timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, method, **kwargs)

    def dispose(self) -> None:
        """Close the underlying connection pool."""
        self._pool.clear()

    def _resolve_url(self, url: str) -> str:
        """Resolve a URL, prepending base_url for relative paths."""
        if not url.startswith(("http://", "https://")):
            # Join with exactly one "/" between base_url and the relative path.
            return self._base_url.rstrip("/") + "/" + url.lstrip("/")
        return url

    def _build_headers(self, kwargs: dict[str, Any]) -> dict[str, str]:
        """Merge extra_headers with per-request headers (per-request wins)."""
        headers = dict(self._extra_headers)
        if kwargs.get("headers"):
            headers.update(kwargs["headers"])
        return headers

    def _prepare_body(self, headers: dict[str, str], kwargs: dict[str, Any]) -> bytes | None:
        """Prepare the request body from json_data, form, or data kwargs.

        Priority: json_data > form > data. Only one should be provided.
        A default Content-Type is added to ``headers`` (in place) for JSON and
        form bodies unless the caller already supplied one in any letter case.

        Args:
            headers: The request headers; may be mutated to add Content-Type.
            kwargs: The per-request keyword arguments.

        Returns:
            The encoded body, or None when there is no body to send.
        """
        json_data = kwargs.get("json_data")
        form = kwargs.get("form")
        data = kwargs.get("data")

        if json_data is not None:
            _setdefault_header(headers, "Content-Type", "application/json")
            return json.dumps(json_data).encode("utf-8")
        if form is not None:
            _setdefault_header(headers, "Content-Type", "application/x-www-form-urlencoded")
            return urllib.parse.urlencode(form).encode("utf-8")
        if data is not None:
            if isinstance(data, dict):
                _setdefault_header(headers, "Content-Type", "application/x-www-form-urlencoded")
                return urllib.parse.urlencode(data).encode("utf-8")
            if isinstance(data, str):
                return data.encode("utf-8")
            if isinstance(data, bytes):
                return data
        return None

    def _append_params(self, url: str, kwargs: dict[str, Any]) -> str:
        """Append query parameters to the URL, ahead of any ``#fragment``."""
        params = kwargs.get("params")
        if not params:
            return url
        # The query string must precede the fragment, so split it off first.
        base, hash_sep, fragment = url.partition("#")
        separator = "&" if "?" in base else "?"
        return base + separator + urllib.parse.urlencode(params) + hash_sep + fragment

    def _execute_request(
        self, method: str, url: str, headers: dict[str, str], body: bytes | None, kwargs: dict[str, Any]
    ) -> urllib3.BaseHTTPResponse:
        """Execute the HTTP request via urllib3.

        Per-request ``timeout`` and ``max_redirects`` in ``kwargs`` override the
        context defaults. Retries are disabled for everything except redirects.
        """
        timeout = kwargs.get("timeout", self._timeout)
        max_redirects = kwargs.get("max_redirects", self._max_redirects)
        follow = max_redirects > 0

        retries = Retry(
            connect=0,
            read=0,
            status=0,
            other=0,
            redirect=max_redirects if follow else 0,
            raise_on_redirect=False,
        )

        return self._pool.request(
            method,
            url,
            headers=headers,
            body=body,
            timeout=timeout,
            redirect=follow,
            retries=retries,
            preload_content=True,
        )

    def _build_response(self, resp: urllib3.BaseHTTPResponse, url: str) -> APIResponse:
        """Build an APIResponse from a urllib3 response."""
        # Merge duplicate headers per RFC 7230 §3.2.2 (combine with ", ")
        resp_headers: dict[str, str] = {}
        for k, v in resp.headers.items():
            key = k.lower()
            if key in resp_headers:
                resp_headers[key] = resp_headers[key] + ", " + v
            else:
                resp_headers[key] = v

        # urllib3 2.x removed resp.reason; fall back to stdlib phrase lookup
        reason = getattr(resp, "reason", None)
        status_text = reason or http_status_phrases.get(resp.status, "")

        return APIResponse(
            status=resp.status,
            status_text=status_text,
            headers=resp_headers,
            url=url,
            body=resp.data,
        )

    def _get_cookies_for_request(self, url: str) -> list[dict]:
        """Get cookies that should be sent with the request. Overridden by subclasses."""
        return []

    def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
        """Process Set-Cookie headers from the response. Overridden by subclasses."""

    def _fetch(self, url: str, method: str, **kwargs: Any) -> APIResponse:
        """Execute an HTTP request with cookie handling.

        Args:
            url: The request URL.
            method: The HTTP method.
            **kwargs: Optional arguments.

        Returns:
            An APIResponse object.

        Raises:
            APIRequestFailure: If fail_on_status_code is in effect and the
                response status is not in the 200-299 range.
        """
        url = self._resolve_url(url)
        headers = self._build_headers(kwargs)

        # Apply cookies
        matching_cookies = self._get_cookies_for_request(url)
        if matching_cookies:
            cookie_header = "; ".join(f"{c['name']}={c['value']}" for c in matching_cookies)
            # Header names are case-insensitive: extend a caller-supplied cookie
            # header under its own key instead of emitting a second one.
            existing_key = next((k for k in headers if k.lower() == "cookie"), None)
            if existing_key is not None:
                headers[existing_key] = headers[existing_key] + "; " + cookie_header
            else:
                headers["Cookie"] = cookie_header

        body = self._prepare_body(headers, kwargs)
        url = self._append_params(url, kwargs)
        resp = self._execute_request(method, url, headers, body, kwargs)

        # After redirects, associate cookies with the final destination's
        # origin, not the initial request URL.
        final_url = _resolve_redirect_url(resp, url)

        # Process response cookies
        set_cookie_headers = _get_set_cookie_headers(resp)
        if set_cookie_headers:
            self._handle_response_cookies(set_cookie_headers, final_url)

        response = self._build_response(resp, final_url)

        fail = kwargs.get("fail_on_status_code", self._fail_on_status_code)
        if fail and not response.ok:
            raise APIRequestFailure(response)

        return response
class APIRequestContext(_BaseRequestContext):
    """Makes HTTP requests with automatic browser cookie synchronization.

    Cookies from the browser session are sent with API requests, and cookies
    from API responses are synced back to the browser.

    Args:
        driver: The WebDriver instance to sync cookies with.
        base_url: Optional base URL prepended to relative request paths.
        extra_headers: Optional headers included in every request.
        timeout: Default request timeout in seconds.
        max_redirects: Maximum number of redirects to follow.
        fail_on_status_code: If True, raise APIRequestFailure for non-2xx responses.
    """

    def __init__(
        self,
        driver: "WebDriver",
        base_url: str = "",
        extra_headers: dict[str, str] | None = None,
        timeout: float = 30.0,
        max_redirects: int = 10,
        fail_on_status_code: bool = False,
    ) -> None:
        super().__init__(
            base_url=base_url,
            extra_headers=extra_headers,
            timeout=timeout,
            max_redirects=max_redirects,
            fail_on_status_code=fail_on_status_code,
        )
        self._driver = driver

    def new_context(
        self,
        base_url: str = "",
        extra_headers: dict[str, str] | None = None,
        storage_state: dict | str | pathlib.Path | None = None,
        fail_on_status_code: bool = False,
    ) -> "_IsolatedAPIRequestContext":
        """Create an isolated API request context that does not sync with the browser.

        The new context inherits this context's timeout and max_redirects.

        Args:
            base_url: Optional base URL for this context.
            extra_headers: Optional headers for this context.
            storage_state: Optional cookies to pre-load, as a dict, JSON file path, or Path.
            fail_on_status_code: If True, raise APIRequestFailure for non-2xx responses.

        Returns:
            An _IsolatedAPIRequestContext instance.

        Raises:
            FileNotFoundError: If storage_state is a path that does not exist.
            ValueError: If the storage state file is not valid JSON, or the
                storage state is not a mapping.
            OSError: If the storage state file cannot be read.
        """
        cookies: list[dict] = []
        if storage_state is not None:
            if isinstance(storage_state, (str, pathlib.Path)):
                file_path = pathlib.Path(storage_state)
                if not file_path.exists():
                    raise FileNotFoundError(f"Storage state file not found: {file_path}")
                try:
                    # JSON is UTF-8 by specification; do not depend on the locale.
                    with open(file_path, encoding="utf-8") as f:
                        state = json.load(f)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON in storage state file {file_path}: {e}") from e
                except OSError as e:
                    raise OSError(f"Cannot read storage state file {file_path}: {e}") from e
            else:
                state = storage_state
            try:
                # "cookies": null is treated the same as a missing key.
                cookies = list(state.get("cookies") or [])
            except AttributeError as e:
                raise ValueError(
                    f"Storage state must be a mapping with an optional 'cookies' list, got {type(state).__name__}"
                ) from e

        return _IsolatedAPIRequestContext(
            base_url=base_url,
            extra_headers=extra_headers,
            cookies=cookies,
            timeout=self._timeout,
            max_redirects=self._max_redirects,
            fail_on_status_code=fail_on_status_code,
        )

    def get_storage_state(self, path: str | pathlib.Path | None = None) -> dict[str, Any]:
        """Export the current browser cookies as a storage state dict.

        Args:
            path: Optional file path to save the storage state as JSON.

        Returns:
            A dict with a "cookies" key containing the browser cookies.

        Raises:
            OSError: If the storage state cannot be written to ``path``.
        """
        cookies = self._driver.get_cookies()
        state: dict[str, Any] = {"cookies": cookies}

        if path is not None:
            file_path = pathlib.Path(path)
            try:
                with open(file_path, "w", encoding="utf-8") as f:
                    json.dump(state, f, indent=2)
            except OSError as e:
                raise OSError(f"Cannot write storage state to {file_path}: {e}") from e

        return state

    def _get_cookies_for_request(self, url: str) -> list[dict]:
        """Get matching browser cookies for the request URL.

        Best-effort: if the browser cookies cannot be read, no cookies are sent.
        """
        try:
            browser_cookies = self._driver.get_cookies()
        except Exception:
            logger.debug("Could not retrieve browser cookies", exc_info=True)
            return []

        # Derive default domain from the browser's current page for host-only cookies
        default_domain = ""
        try:
            current = self._driver.current_url
            if current:
                default_domain = urllib.parse.urlparse(current).hostname or ""
        except Exception:
            logger.debug("Could not get current URL for host-only cookie matching", exc_info=True)

        return [c for c in browser_cookies if _cookie_matches(c, url, default_domain)]

    def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
        """Sync Set-Cookie headers back to the browser.

        Already-expired cookies are deleted from the browser; all others are
        added. Failures are logged, never raised, so a cookie the browser
        refuses does not fail the API request.
        """
        parsed_url = urllib.parse.urlparse(url)
        now = int(time.time())
        for sc_header in set_cookie_headers:
            cookie = _parse_set_cookie(sc_header)
            if not cookie.get("name"):
                continue

            cookie.setdefault("domain", parsed_url.hostname or "")
            cookie.setdefault("path", "/")

            # An expiry in the past (e.g. Max-Age=0) is a deletion request.
            expiry = cookie.get("expiry")
            if expiry is not None and expiry <= now:
                try:
                    self._driver.delete_cookie(cookie["name"])
                except Exception:
                    logger.debug(
                        "Could not delete expired cookie '%s' from browser",
                        cookie["name"],
                        exc_info=True,
                    )
                continue

            try:
                self._driver.add_cookie(cookie)
            except Exception:
                logger.warning(
                    "Could not sync cookie '%s' to browser (domain mismatch with current page)",
                    cookie.get("name"),
                    exc_info=True,
                )
class _IsolatedAPIRequestContext(_BaseRequestContext):
    """An isolated API request context that maintains its own cookie jar.

    Does not synchronize cookies with any browser session.

    Jar entries follow the same convention ``_cookie_matches`` reads: a domain
    with a leading dot also covers subdomains, a domain without one matches a
    single host exactly.
    """

    def __init__(
        self,
        base_url: str = "",
        extra_headers: dict[str, str] | None = None,
        cookies: list[dict] | None = None,
        timeout: float = 30.0,
        max_redirects: int = 10,
        fail_on_status_code: bool = False,
    ) -> None:
        super().__init__(
            base_url=base_url,
            extra_headers=extra_headers,
            timeout=timeout,
            max_redirects=max_redirects,
            fail_on_status_code=fail_on_status_code,
        )
        # Copy so the jar never aliases a list the caller may keep mutating.
        self._cookies: list[dict] = list(cookies) if cookies else []

    def get_storage_state(self) -> dict[str, Any]:
        """Return the current cookies as a storage state dict."""
        return {"cookies": list(self._cookies)}

    def _get_cookies_for_request(self, url: str) -> list[dict]:
        """Get matching cookies from the internal jar."""
        # For isolated contexts, use the request hostname as default domain
        default_domain = urllib.parse.urlparse(url).hostname or ""
        return [c for c in self._cookies if _cookie_matches(c, url, default_domain)]

    def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
        """Store Set-Cookie headers in the internal jar.

        Args:
            set_cookie_headers: Raw Set-Cookie header values from the response.
            url: URL of the response that carried the headers.
        """
        host = (urllib.parse.urlparse(url).hostname or "").lower()
        now = int(time.time())

        for sc_header in set_cookie_headers:
            cookie = _parse_set_cookie(sc_header)
            if not cookie.get("name"):
                continue

            # RFC 6265 §5.2.3: a leading dot in the Domain attribute is ignored
            # and the value is case-insensitive.
            declared_domain = (cookie.get("domain") or "").lstrip(".").lower()
            if declared_domain:
                # RFC 6265 §5.3 step 6: ignore a cookie whose Domain attribute
                # does not cover the host that sent it.
                if host != declared_domain and not host.endswith("." + declared_domain):
                    continue
                # Store in leading-dot form so the cookie is sent to the
                # declared domain and its subdomains, including the origin host.
                cookie["domain"] = "." + declared_domain
            else:
                # No (or empty) Domain attribute: bind the cookie to the exact
                # responding host instead of leaving a domain-less entry.
                cookie["domain"] = host
            cookie.setdefault("path", "/")

            # Cookies are unique by (name, domain, path)
            key = (cookie["name"], cookie["domain"], cookie["path"])

            # Remove existing cookie with same key
            self._cookies = [
                c for c in self._cookies if (c.get("name"), c.get("domain", ""), c.get("path", "/")) != key
            ]

            # Only store if not expired (Max-Age=0 or negative means delete)
            expiry = cookie.get("expiry")
            if expiry is not None and expiry <= now:
                continue

            self._cookies.append(cookie)