aboutsummaryrefslogtreecommitdiff
path: root/sjdbmk/common.py
blob: 642b2518c56073dc9c27d86f229cf3c5d8c5a3ae (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
#
# Common functions for the Daily Bulletin Build System
# Copyright (C) 2024 Runxi Yu <https://runxiyu.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

from typing import Any, Iterable, Iterator
import logging
import re
import base64
import shutil

import requests
import msal  # type: ignore


def acquire_token(
    graph_client_id: str,
    graph_authority: str,
    graph_username: str,
    graph_password: str,
    graph_scopes: list[str],
) -> str:
    """Obtain a Microsoft Graph access token with username/password login.

    Builds an MSAL public client application for *graph_client_id* at
    *graph_authority* and calls ``acquire_token_by_username_password``.

    Args:
        graph_client_id: application (client) ID registered with the authority.
        graph_authority: authority URL passed to MSAL.
        graph_username: account user name.
        graph_password: account password.
        graph_scopes: scopes requested for the token.

    Returns:
        The ``access_token`` string from MSAL's result.

    Raises:
        ValueError: if the result carries no string ``access_token``. The
            message includes MSAL's ``error`` and ``error_description``
            fields (``None`` when absent) so the cause can be diagnosed.
    """
    app = msal.PublicClientApplication(
        graph_client_id,
        authority=graph_authority,
    )
    result = app.acquire_token_by_username_password(
        graph_username, graph_password, scopes=graph_scopes
    )

    token = result.get("access_token")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if isinstance(token, str):
        return token
    raise ValueError(
        "Authentication error in password login: %s: %s"
        % (result.get("error"), result.get("error_description"))
    )


def search_mail(
    token: str, query_string: str, size: int = 15
) -> list[dict[str, Any]]:
    """Search the signed-in user's mail through the Microsoft Graph search API.

    Args:
        token: bearer token for Microsoft Graph.
        query_string: query string sent as ``query.queryString``.
        size: maximum number of hits requested (defaults to 15).

    Returns:
        The ``hits`` list of the first hits container in the response. An
        empty list is returned when the response contains no hits.

    Raises:
        requests.HTTPError: if Graph answers with an HTTP error status.
        ValueError: if ``hits`` is not a list of dicts.
    """
    response = requests.post(
        "https://graph.microsoft.com/v1.0/search/query",
        headers={"Authorization": "Bearer " + token},
        json={
            "requests": [
                {
                    "entityTypes": ["message"],
                    "query": {"queryString": query_string},
                    "from": 0,
                    "size": size,
                    "enableTopResults": True,
                }
            ]
        },
        timeout=20,
    )
    # Surface HTTP failures directly instead of as a KeyError on the body.
    response.raise_for_status()
    containers = response.json()["value"][0]["hitsContainers"]
    # Tolerate a result with no hits (missing or empty list) instead of
    # failing on an index/key lookup.
    hits = containers[0].get("hits", []) if containers else []
    if not isinstance(hits, list) or not all(isinstance(hit, dict) for hit in hits):
        raise ValueError("Unexpected response shape from Graph search query")
    return hits


def encode_sharing_url(url: str) -> str:
    """Turn a sharing URL into a Graph ``/shares/{id}`` sharing token.

    The URL is UTF-8 encoded, converted to URL-safe base64 with the trailing
    ``=`` padding removed, and prefixed with ``u!``.
    """
    raw = url.encode("utf-8")
    unpadded = base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
    return f"u!{unpadded}"


def download_share_url(
    token: str, url: str, local_filename: str, chunk_size: int = 65536
) -> None:
    """Download the file behind a sharing URL to *local_filename*.

    The sharing URL is resolved to a driveItem through the Graph ``/shares``
    endpoint. The item's ``@microsoft.graph.downloadUrl`` is then streamed
    to disk.

    Args:
        token: bearer token for Microsoft Graph, sent on both requests.
        url: sharing URL of the file.
        local_filename: path the file content is written to (overwritten).
        chunk_size: buffer size in bytes used when copying the response body
            to the file.

    Raises:
        requests.HTTPError: if either request returns an HTTP error status.
            The local file is not opened in that case.
    """
    meta = requests.get(
        "https://graph.microsoft.com/v1.0/shares/%s/driveItem"
        % encode_sharing_url(url),
        headers={"Authorization": "Bearer " + token},
        timeout=20,
    )
    meta.raise_for_status()
    download_direct_url = meta.json()["@microsoft.graph.downloadUrl"]

    with requests.get(
        download_direct_url,
        headers={
            "Authorization": "Bearer %s" % token,
            # r.raw is copied verbatim below, so ask for an uncompressed body.
            "Accept-Encoding": "identity",
        },
        stream=True,
        timeout=20,
    ) as r:
        # Check before opening the file so an error body is never saved as
        # the downloaded content.
        r.raise_for_status()
        with open(local_filename, "wb") as fd:
            shutil.copyfileobj(r.raw, fd, chunk_size)


def filter_mail_results_by_sender(
    original: Iterable[dict[str, Any]], sender: str
) -> Iterator[dict[str, Any]]:
    """Lazily yield the search hits sent from *sender*.

    Each hit's ``["resource"]["sender"]["emailAddress"]["address"]`` is
    compared with *sender* case-insensitively (via ``str.lower``).
    """
    wanted = sender.lower()
    for candidate in original:
        address = candidate["resource"]["sender"]["emailAddress"]["address"]
        if address.lower() != wanted:
            continue
        yield candidate


# TODO: Potentially replace this with a pattern-match based on strptime().
def filter_mail_results_by_subject_regex_groups(
    original: Iterable[dict[str, Any]],
    subject_regex: str,
    subject_regex_groups: Iterable[int],
) -> Iterator[tuple[dict[str, Any], list[str]]]:
    """Lazily yield search hits whose subject matches *subject_regex*.

    The pattern is applied with ``re.match``, so it is anchored at the start
    of ``hit["resource"]["subject"]``.

    Args:
        original: search hits to filter.
        subject_regex: regular expression tested against each subject.
        subject_regex_groups: group numbers to extract from every match.

    Yields:
        ``(hit, groups)`` where ``groups`` holds the text of the requested
        groups in the order given. An optional group that did not
        participate in the match yields ``None``.
    """
    pattern = re.compile(subject_regex)
    # Materialize once: a one-shot iterator would otherwise be exhausted by
    # the first match, giving every later match an empty group list.
    groups = tuple(subject_regex_groups)
    for hit in original:
        subject = hit["resource"]["subject"]
        logging.debug("Trying %s", subject)
        matched = pattern.match(subject)
        if matched:
            yield (hit, [matched.group(group) for group in groups])


class DailyBulletinError(Exception):
    """Project-specific exception type for the Daily Bulletin Build System.

    It adds no behavior to :class:`Exception`. Nothing in this module raises
    it, so it is presumably raised and caught elsewhere in the build.
    """