# Source code for flask_limiter.wrappers

from __future__ import annotations

import typing
import weakref
from typing import Callable, Iterator, List, Optional, Sequence, Tuple, Union

from flask import request
from flask.wrappers import Response
from limits import RateLimitItem, parse_many
from limits.strategies import RateLimiter

if typing.TYPE_CHECKING:
    from .extension import Limiter


class RequestLimit:
    """
    Provides details of a rate limit within the context of a request
    """

    #: The instance of the rate limit
    limit: RateLimitItem

    #: The full key for the request against which the rate limit is tested
    key: str

    #: Whether the limit was breached within the context of this request
    breached: bool

    def __init__(
        self,
        extension: Limiter,
        limit: RateLimitItem,
        request_args: List[str],
        breached: bool,
    ) -> None:
        """
        :param extension: the limiter extension; only a weak proxy to it is
         stored, so this object does not keep the extension alive
        :param limit: the rate limit item being described
        :param request_args: identifiers passed to ``limit.key_for`` (and
         later to ``get_window_stats``) to address this limit
        :param breached: whether the limit was breached for this request
        """
        self.extension: weakref.ProxyType[Limiter] = weakref.proxy(extension)
        self.limit = limit
        self.request_args = request_args
        self.key = limit.key_for(*request_args)
        self.breached = breached
        # Lazily populated cache for :attr:`window`; ``None`` means
        # "not fetched yet".
        self._window: Optional[Tuple[int, int]] = None

    @property
    def limiter(self) -> RateLimiter:
        """The ``limiter`` attribute of the owning extension"""
        return typing.cast(RateLimiter, self.extension.limiter)

    @property
    def window(self) -> Tuple[int, int]:
        """
        Window stats for this limit as returned by
        ``limiter.get_window_stats``; fetched on first access and cached
        on the instance afterwards.
        """
        # Explicit ``is None`` so that the cache check does not depend on the
        # truthiness of whatever the storage returned.
        if self._window is None:
            self._window = self.limiter.get_window_stats(
                self.limit, *self.request_args
            )
        return self._window

    @property
    def reset_at(self) -> int:
        """Timestamp at which the rate limit will be reset"""
        return int(self.window[0] + 1)

    @property
    def remaining(self) -> int:
        """Quantity remaining for this rate limit"""
        return self.window[1]
def _lowercase_methods(
    methods: Optional[Sequence[str]],
) -> Optional[Sequence[str]]:
    """
    Return a list of the lower-cased *methods*; ``None`` or an empty
    sequence is returned unchanged.
    """
    return [m.lower() for m in methods] if methods else methods


class Limit:
    """
    simple wrapper to encapsulate limits and their context
    """

    def __init__(
        self,
        limit: RateLimitItem,
        key_func: Callable[[], str],
        scope: Optional[Union[str, Callable[[str], str]]],
        per_method: bool,
        methods: Optional[Sequence[str]],
        error_message: Optional[str],
        exempt_when: Optional[Callable[[], bool]],
        override_defaults: Optional[bool],
        deduct_when: Optional[Callable[[Response], bool]],
        on_breach: Optional[Callable[[RequestLimit], Optional[Response]]],
        cost: Union[Callable[[], int], int],
    ) -> None:
        """
        :param limit: the rate limit item being wrapped
        :param key_func: callable returning the key for the limit
        :param scope: a fixed scope string, or a callable that receives the
         current request endpoint and returns the scope
        :param per_method: whether the http method is appended to the scope
         (see :meth:`args_for`)
        :param methods: http methods the limit applies to (``None`` for all);
         stored lower-cased
        :param error_message: stored as given
        :param exempt_when: callable that returns ``True`` when the limit
         should be skipped
        :param override_defaults: stored as given
        :param deduct_when: stored as given
        :param on_breach: stored as given
        :param cost: an integer cost, or a callable returning one
        """
        self.limit = limit
        self.key_func = key_func
        self.__scope = scope
        self.per_method = per_method
        # ``method_exempt`` compares against the lower-cased request method,
        # so the configured methods are normalised here as well. Otherwise a
        # Limit created directly with e.g. ``["GET"]`` would never match.
        self.methods = _lowercase_methods(methods)
        self.error_message = error_message
        self.exempt_when = exempt_when
        self.override_defaults = override_defaults
        self.deduct_when = deduct_when
        self.on_breach = on_breach
        self._cost = cost

    @property
    def is_exempt(self) -> bool:
        """Check if the limit is exempt."""
        if self.exempt_when:
            return self.exempt_when()
        return False

    @property
    def scope(self) -> Optional[str]:
        """
        The configured scope. A callable scope is invoked with the current
        request endpoint (or an empty string when there is none).
        """
        if callable(self.__scope):
            return self.__scope(request.endpoint or "")
        return self.__scope

    @property
    def cost(self) -> int:
        """The cost of a hit; a callable cost is evaluated on each access"""
        if isinstance(self._cost, int):
            return self._cost
        return self._cost()

    @property
    def method_exempt(self) -> bool:
        """Check if the limit is not applicable for this method"""
        return (
            self.methods is not None
            and request.method.lower() not in self.methods
        )

    def args_for(
        self, endpoint: str, key: str, method: Optional[str]
    ) -> List[str]:
        """
        Build the ``[key, scope]`` identifiers for this limit.

        :param endpoint: used as the scope when no scope is configured
        :param key: the key for the limit
        :param method: the http method; required when ``per_method`` is set,
         in which case ``:<METHOD>`` (upper-cased) is appended to the scope
        """
        scope = self.scope or endpoint
        if self.per_method:
            # Internal invariant: a per-method limit is always evaluated with
            # a concrete http method.
            assert method
            scope += f":{method.upper()}"
        return [key, scope]


class LimitGroup:
    """
    represents a group of related limits either from a string or a callable
    that returns one
    """

    def __init__(
        self,
        limit_provider: Union[Callable[[], str], str],
        key_function: Callable[[], str],
        scope: Optional[Union[str, Callable[[str], str]]],
        per_method: bool,
        methods: Optional[Sequence[str]],
        error_message: Optional[str],
        exempt_when: Optional[Callable[[], bool]],
        override_defaults: Optional[bool],
        deduct_when: Optional[Callable[[Response], bool]],
        on_breach: Optional[Callable[[RequestLimit], Optional[Response]]],
        cost: Optional[Union[Callable[[], int], int]],
    ) -> None:
        """
        :param limit_provider: a rate limit string, or a callable returning
         one
        :param key_function: callable returning the key for the limits
        :param scope: a fixed scope string, or a callable that receives the
         current request endpoint and returns the scope
        :param per_method: forwarded to every :class:`Limit` in the group
        :param methods: http methods the limits apply to (``None`` for all);
         stored lower-cased
        :param error_message: forwarded to every :class:`Limit`
        :param exempt_when: forwarded to every :class:`Limit`
        :param override_defaults: forwarded to every :class:`Limit`
        :param deduct_when: forwarded to every :class:`Limit`
        :param on_breach: forwarded to every :class:`Limit`
        :param cost: an integer cost, or a callable returning one; any falsy
         value (``None`` but also ``0``) falls back to ``1``
        """
        self.__limit_provider = limit_provider
        self.__scope = scope
        self.key_function = key_function
        self.per_method = per_method
        self.methods = _lowercase_methods(methods)
        self.error_message = error_message
        self.exempt_when = exempt_when
        self.override_defaults = override_defaults
        self.deduct_when = deduct_when
        self.on_breach = on_breach
        self.cost = cost or 1

    def __iter__(self) -> Iterator[Limit]:
        """
        Yield one :class:`Limit` per item that ``parse_many`` returns for the
        limit string. A callable provider is invoked on every iteration.
        """
        provider = self.__limit_provider
        limit_string = provider() if callable(provider) else provider
        for limit in parse_many(limit_string):
            yield Limit(
                limit,
                self.key_function,
                self.__scope,
                self.per_method,
                self.methods,
                self.error_message,
                self.exempt_when,
                self.override_defaults,
                self.deduct_when,
                self.on_breach,
                self.cost,
            )