diff --git a/advent/parser/parser.py b/advent/parser/parser.py index 9772844..057136b 100644 --- a/advent/parser/parser.py +++ b/advent/parser/parser.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass from functools import reduce from itertools import chain -from typing import Any, Callable, Generic, Self, TypeVar, overload +from typing import Any, Callable, Generic, Iterator, Self, TypeVar, overload import unicodedata from .result import Result @@ -42,7 +42,7 @@ class ParserInput: return f'{self.input[self.start-3:self.start-1]}->[{self.input[self.start:]}]' -ParserResult = Result[tuple[ParserInput, T]] +ParserResult = Iterator[tuple[ParserInput, T]] ParserFunc = Callable[[ParserInput], ParserResult[T]] @@ -51,16 +51,30 @@ class P(Generic[T]): self.func = func def parse(self, s: str, i: int = 0) -> Result[T]: - result = self.func(ParserInput(s, i)) - return result.fmap(lambda pr: pr[1]) + all_results = self.func(ParserInput(s, i)) + try: + _, result = next(all_results) + return Result.of(result) + except StopIteration: + return Result.fail("No result") + + def parse_multi(self, s: str, i: int = 0) -> Iterator[T]: + return (v for _, v in self.func(ParserInput(s, i))) @staticmethod def pure(value: T) -> P[T]: - return P(lambda parserPos: Result.of((parserPos, value))) + def inner(parserPos: ParserInput) -> ParserResult[T]: + yield (parserPos, value) + return P(inner) @staticmethod - def fail(text: str) -> P[Any]: - return P(lambda pp: Result.fail(f'@: {pp}\ntext: {text}')) + def fail() -> P[Any]: + def inner(_: ParserInput) -> ParserResult[Any]: + if False: + yield + pass + + return P(inner) @staticmethod def _fix(p1: Callable[[P[Any]], P[T]]) -> P[T]: @@ -72,24 +86,22 @@ class P(Generic[T]): return self def bind(self, bind_func: Callable[[T], P[TR]]) -> P[TR]: - def inner(x: tuple[ParserInput, T]): - parserPos, v = x - return bind_func(v).func(parserPos) - - return P(lambda parserPos: self.func(parserPos).bind(inner)) + def inner(parserPos: ParserInput) -> ParserResult[TR]: + return (r for rs in (bind_func(v).func(pp) + for pp, v in self.func(parserPos)) for r in rs) + return P(inner) def fmap(self, map_func: Callable[[T], TR]) -> P[TR]: - def inner(x: tuple[ParserInput, T]): - parserPos, v = x - return (parserPos, map_func(v)) - return P(lambda parserPos: self.func(parserPos).fmap(inner)) + def inner(parserPos: ParserInput) -> ParserResult[TR]: + return ((pp, map_func(v)) for pp, v in self.func(parserPos)) + return P(inner) def safe_fmap(self, map_func: Callable[[T], TR]) -> P[TR]: def inner(value: T) -> P[TR]: try: return P.pure(map_func(value)) - except Exception as e: - return P.fail(f'Failed with {e}') + except Exception: + return P.fail() return self.bind(inner) def replace(self, value: TR) -> P[TR]: @@ -115,8 +127,7 @@ class P(Generic[T]): return P.either(self.some(), P.pure([])) def satisfies(self, pred: Callable[[T], bool], failtext: str) -> P[T]: - return self.bind(lambda v: P.pure(v) if pred( - v) else P.fail(f'Does not satisfy: {failtext}')) + return self.bind(lambda v: P.pure(v) if pred(v) else P.fail()) def optional(self) -> P[T | None]: return P.either(self, P.pure(None)) @@ -133,13 +144,15 @@ class P(Generic[T]): return P.map2(p1, p2, lambda v1, _: v1) @staticmethod - def no_match(p: P[Any], failtext: str) -> P[tuple[()]]: - def inner(pp: ParserInput) -> ParserResult[tuple[()]]: - result = p.func(pp) - if result.is_fail(): - return Result.of((pp, ())) - else: - return Result.fail(failtext) + def no_match(p: P[Any]) -> P[tuple[()]]: + def inner(parserPos: ParserInput) -> ParserResult[tuple[()]]: + result = p.func(parserPos) + try: + next(result) + # Silently yields nothing so is an empty Generator + except StopIteration: + yield (parserPos, ()) + return P(inner) @ staticmethod @@ -233,8 +246,8 @@ class P(Generic[T]): @ staticmethod def either(p1: P[T1], p2: P[T2], /) -> P[T1 | T2]: def inner(parserPos: ParserInput): - result = p1.func(parserPos) - return result if result.is_ok() else p2.func(parserPos) + yield from p1.func(parserPos) + yield from p2.func(parserPos) return P(inner) @@ -256,7 +269,7 @@ class P(Generic[T]): @ staticmethod def choice(*ps: P[Any]) -> P[Any]: - return reduce(P.either, ps, P.fail('No Choice matched')) + return reduce(P.either, ps, P.fail()) @ staticmethod def choice_same(*ps: P[T]) -> P[T]: @@ -264,12 +277,17 @@ class P(Generic[T]): @ staticmethod def any_char() -> P[str]: - return P(lambda pp: Result.of(pp.step()) if not pp.is_eof() - else Result.fail('At EOF')) + def inner(parserPos: ParserInput) -> ParserResult[str]: + if not parserPos.is_eof(): + yield parserPos.step() + return P(inner) @ staticmethod def eof() -> P[tuple[()]]: - return P.no_match(P.any_char(), 'Not at eof') + def inner(parserPos: ParserInput) -> ParserResult[tuple[()]]: + if parserPos.is_eof(): + yield parserPos, () + return P(inner) @ staticmethod def is_char(cmp: str) -> P[str]: @@ -277,7 +295,7 @@ class P(Generic[T]): @staticmethod def is_not_char(s: str) -> P[tuple[()]]: - return P.no_match(P.is_char(s), f'Did match {s}') + return P.no_match(P.is_char(s)) @ staticmethod def char_by_func(cmp: Callable[[str], bool], failtext: str) -> P[str]: @@ -298,12 +316,12 @@ class P(Generic[T]): @ staticmethod def is_decimal(num: int) -> P[str]: return P.any_decimal().bind( - lambda c: P.pure(c) if unicodedata.decimal(c) == num else P.fail(f'Not {num}')) + lambda c: P.pure(c) if unicodedata.decimal(c) == num else P.fail()) @ staticmethod def is_not_decimal(num: int) -> P[str]: return P.any_decimal().bind( - lambda c: P.pure(c) if unicodedata.decimal(c) != num else P.fail(f'Is {num}')) + lambda c: P.pure(c) if unicodedata.decimal(c) != num else P.fail()) @ staticmethod def lower() -> P[str]: @@ -323,7 +341,7 @@ class P(Generic[T]): @ staticmethod def unsigned() -> P[int]: - return P.either(P.fst(P.is_decimal(0), P.no_match(P.any_decimal(), 'starting Zero')), + return P.either(P.fst(P.is_decimal(0), P.no_match(P.any_decimal())), P.map2(P.is_not_decimal(0), P.any_decimal().many(), lambda f, s: f + ''.join(s)) ).fmap(int) diff --git a/advent/parser/test_parser.py b/advent/parser/test_parser.py index f2c0c8b..779b45a 100644 --- a/advent/parser/test_parser.py +++ b/advent/parser/test_parser.py @@ -130,7 +130,7 @@ def test_seq_seq(): def test_not(): input = 'a' - parser = P.snd(P.no_match(P.is_char('!'), 'found !'), P.is_char('a')) + parser = P.snd(P.no_match(P.is_char('!')), P.is_char('a')) expected = 'a' result = parser.parse(input).get() assert result == expected @@ -138,3 +138,30 @@ def test_not(): input2 = '!' result2 = parser.parse(input2) assert result2.is_fail() + + +def test_multi(): + input = 'aa' + parser = P.is_char('a').many() + expected = [['a', 'a'], ['a'], []] + result = list(parser.parse_multi(input)) + assert result == expected + + +def test_either(): + input = 'aab' + parser = P.either( + P.seq( + P.is_char('a').many(), P.string('b')), P.seq( + P.string('a'), P.string('ab'))) + expected = [(['a', 'a'], 'b'), ('a', 'ab')] + result = list(parser.parse_multi(input)) + assert result == expected + + +def test_seq_eof(): + input = 'aa' + parser = P.seq(P.is_char('a').many(), P.eof()) + expected = [(['a', 'a'], ())] + result = list(parser.parse_multi(input)) + assert result == expected