Skip to content

Commit

Permalink
Merge pull request #7 from Azureblade3808/v0.2.x
Browse files Browse the repository at this point in the history
V0.2.1
  • Loading branch information
Azureblade3808 authored Sep 12, 2024
2 parents 782239a + 8346182 commit d330df1
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 51 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ There are two designed ways to use `defer`. You may use either of them, or mix t
... defer and print(3)
... defer and print(4)

>>> import sys
>>> if sys.version_info < (3, 12):
... from deferrer import defer_scope
... f = defer_scope(f)

>>> f()
2
4
Expand All @@ -46,6 +51,11 @@ There are two designed ways to use `defer`. You may use either of them, or mix t
... defer(print)(3)
... defer(print)(4)

>>> import sys
>>> if sys.version_info < (3, 12):
... from deferrer import defer_scope
... f = defer_scope(f)

>>> f()
2
4
Expand Down
2 changes: 1 addition & 1 deletion deferrer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.2.0"
__version__ = "0.2.1"

from .__public__ import *
66 changes: 51 additions & 15 deletions deferrer/_defer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@

__all__ = ["defer"]

import sys
from collections.abc import Callable
from types import CellType, FunctionType
from typing import Any, Final, Literal, cast, final
from typing import Any, Final, Generic, Literal, ParamSpec, cast, final
from warnings import warn

from ._code_location import get_code_location
from ._defer_scope import ensure_deferred_actions
from ._deferred_actions import DeferredAction
from ._opcode import Opcode
from ._code_location import get_code_location
from ._frame import get_outer_frame
from ._opcode import Opcode
from ._sequence_matching import WILDCARD, sequence_has_prefix

_P = ParamSpec("_P")

_MISSING = cast("Any", object())

Expand All @@ -23,13 +27,19 @@ class Defer:
Examples
--------
>>> import sys
>>> from deferrer import defer_scope
>>> def f_0():
... defer and print(0)
... defer and print(1)
... print(2)
... defer and print(3)
... defer and print(4)
>>> if sys.version_info < (3, 12):
... f_0 = defer_scope(f_0)
>>> f_0()
2
4
Expand All @@ -44,6 +54,9 @@ class Defer:
... defer(print)(3)
... defer(print)(4)
>>> if sys.version_info < (3, 12):
... f_1 = defer_scope(f_1)
>>> f_1()
2
4
Expand All @@ -66,24 +79,46 @@ def __bool__() -> Literal[False]:
frame = get_outer_frame()

# The usage is `defer and ...` and the typical instructions should be like:
#
# ```
# LOAD_GLOBAL ? (defer)
# COPY
# --> POP_JUMP_IF_FALSE ?
# POP_TOP
# <???>
# ```
# (Python 3.12)
#
# ```
# LOAD_GLOBAL ? (defer)
# --> JUMP_IF_FALSE_OR_POP ?
# <???>
# ```
# (Python 3.11)
#
# The current instruction is at the line prefixed by "-->", and the "<???>" part
# stands for the RHS part in `defer and ...`.
if sys.version_info >= (3, 12):
expected_code_bytes_prefix = (
Opcode.POP_JUMP_IF_FALSE,
WILDCARD,
Opcode.POP_TOP,
0,
)
rhs_part_offset = 2
else:
expected_code_bytes_prefix = (
Opcode.JUMP_IF_FALSE_OR_POP,
WILDCARD,
)
rhs_part_offset = 0

code = frame.f_code
code_bytes = code.co_code
i_code_byte = frame.f_lasti
if not (
True
and len(code_bytes) - i_code_byte >= 4
and code_bytes[i_code_byte + 0] == Opcode.POP_JUMP_IF_FALSE
and code_bytes[i_code_byte + 2] == Opcode.POP_TOP
and code_bytes[i_code_byte + 3] == 0

if not sequence_has_prefix(
code_bytes[i_code_byte:], expected_code_bytes_prefix
):
code_location = get_code_location(frame)
message = (
Expand Down Expand Up @@ -137,11 +172,11 @@ def __bool__() -> Literal[False]:
)
dummy_consts += (value,)

# Copy the bytecode of the `...` part in `defer and ...` into the dummy
# Copy the bytecode of the RHS part in `defer and ...` into the dummy
# function.
n_skipped_bytes = code_bytes[i_code_byte + 1] * 2
dummy_code_bytes += code_bytes[
(i_code_byte + 4) : (i_code_byte + 2 + n_skipped_bytes)
(i_code_byte + 2 + rhs_part_offset) : (i_code_byte + 2 + n_skipped_bytes)
]

# The dummy function should return something. The simplest way is to return
Expand All @@ -155,6 +190,7 @@ def __bool__() -> Literal[False]:
co_kwonlyargcount=0,
co_code=dummy_code_bytes,
co_consts=dummy_consts,
co_linetable=bytes(),
)

new_function = FunctionType(
Expand All @@ -168,7 +204,7 @@ def __bool__() -> Literal[False]:
return False

@staticmethod
def __call__[**P](callable: Callable[P, Any], /) -> Callable[P, None]:
def __call__(callable: Callable[_P, Any], /) -> Callable[_P, None]:
"""
Converts a callable into a deferred callable.
Expand Down Expand Up @@ -199,19 +235,19 @@ def perform(self, /) -> None:


@final
class _DeferredCallable[**P](DeferredAction):
class _DeferredCallable(DeferredAction, Generic[_P]):
_body: Final[Callable[..., Any]]
_code_location: Final[str]

_args_and_kwargs: tuple[tuple[Any, ...], dict[str, Any]] | None

def __init__(self, body: Callable[P, Any], /, code_location: str) -> None:
def __init__(self, body: Callable[_P, Any], /, code_location: str) -> None:
self._body = body
self._code_location = code_location

self._args_and_kwargs = None

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None:
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> None:
if self._args_and_kwargs is not None:
raise RuntimeError("`defer(...)` gets further called more than once.")

Expand Down
15 changes: 13 additions & 2 deletions deferrer/_defer_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
]

import operator
import sys
from collections.abc import Callable, Iterable, Iterator
from contextlib import AbstractContextManager
from functools import update_wrapper
from types import FrameType
from typing import Any, Final, Generic, ParamSpec, TypeVar, cast, final, overload

from ._deferred_actions import DeferredActions
from ._frame import get_outer_frame, is_class_frame, is_global_frame
from ._frame import get_current_frame, get_outer_frame, is_class_frame, is_global_frame

_Wrapped_t = TypeVar("_Wrapped_t")

Expand Down Expand Up @@ -54,6 +55,12 @@ def ensure_deferred_actions(frame: FrameType) -> DeferredActions:
if deferred_actions is not None:
return deferred_actions

if sys.version_info < (3, 12):
raise RuntimeError(
"cannot inject deferred actions into local scope"
+ " with Python older than 3.12"
)

local_scope = frame.f_locals

deferred_actions = local_scope.get(_LOCAL_KEY)
Expand Down Expand Up @@ -123,7 +130,7 @@ def __call__(
) -> _R:
wrapped = self._wrapped

frame = get_outer_frame()
frame = get_current_frame()
deferred_actions = DeferredActions()

_callable_deferred_actions_recorder.add(frame, deferred_actions)
Expand Down Expand Up @@ -170,6 +177,8 @@ def remove(
__ = self._internal_dict.pop(outer_frame)
assert __ is deferred_actions

deferred_actions.drain()

def find(self, frame: FrameType, /) -> DeferredActions | None:
outer_frame = frame.f_back
assert outer_frame is not None
Expand Down Expand Up @@ -206,6 +215,8 @@ def remove(self, frame: FrameType, deferred_actions: DeferredActions, /) -> None
if len(deferred_actions_list) == 0:
del internal_dict[frame]

deferred_actions.drain()

def find(self, frame: FrameType, /) -> DeferredActions | None:
deferred_actions_list = self._internal_dict.get(frame)
if deferred_actions_list is None:
Expand Down
26 changes: 23 additions & 3 deletions deferrer/_frame.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

__all__ = [
"get_current_frame",
"get_outer_frame",
"is_class_frame",
"is_global_frame",
Expand All @@ -13,6 +14,26 @@
from ._opcode import Opcode


def get_current_frame() -> FrameType:
"""
Returns the frame of the caller.
Examples
--------
>>> def foo(): # L0
... frame = get_current_frame()
... print(
... frame.f_code.co_name,
... frame.f_lineno - frame.f_code.co_firstlineno, # L4
... )
>>> foo()
foo 4
"""

return sys._getframe(1) # pyright: ignore[reportPrivateUsage]


def get_outer_frame() -> FrameType:
"""
Returns the frame of the caller of caller.
Expand All @@ -22,7 +43,7 @@ def get_outer_frame() -> FrameType:
>>> def foo(): # L0
... def inner():
... frame = get_outer_frame()
... print (
... print(
... frame.f_code.co_name,
... frame.f_lineno - frame.f_code.co_firstlineno,
... )
Expand All @@ -32,8 +53,7 @@ def get_outer_frame() -> FrameType:
foo 7
"""

frame = sys._getframe(2) # pyright: ignore[reportPrivateUsage]
return frame
return sys._getframe(2) # pyright: ignore[reportPrivateUsage]


def is_global_frame(frame: FrameType, /) -> bool:
Expand Down
8 changes: 7 additions & 1 deletion deferrer/_opcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = ["Opcode"]

import sys
from dis import opmap
from enum import IntEnum

Expand All @@ -14,9 +15,14 @@ class Opcode(IntEnum):
COPY_FREE_VARS = opmap["COPY_FREE_VARS"]
LOAD_CONST = opmap["LOAD_CONST"]
LOAD_NAME = opmap["LOAD_NAME"]
POP_JUMP_IF_FALSE = opmap["POP_JUMP_IF_FALSE"]
POP_TOP = opmap["POP_TOP"]
RESUME = opmap["RESUME"]
RETURN_VALUE = opmap["RETURN_VALUE"]
STORE_FAST = opmap["STORE_FAST"]
STORE_NAME = opmap["STORE_NAME"]

if (3, 12) <= sys.version_info < (3, 13):
POP_JUMP_IF_FALSE = opmap["POP_JUMP_IF_FALSE"]

if (3, 11) <= sys.version_info < (3, 12):
JUMP_IF_FALSE_OR_POP = opmap["JUMP_IF_FALSE_OR_POP"]
25 changes: 25 additions & 0 deletions deferrer/_sequence_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

__all__ = [
"WILDCARD",
"sequence_has_prefix",
]

from collections.abc import Sequence
from typing import Any


class _Wildcard:
"""
An object that equals any object.
"""

def __eq__(self, other: object, /) -> bool:
return True


WILDCARD = _Wildcard()


def sequence_has_prefix(sequence: Sequence[Any], prefix: Sequence[Any], /) -> bool:
return tuple(sequence[: len(prefix)]) == tuple(prefix)
Loading

0 comments on commit d330df1

Please sign in to comment.