Please enable Javascript to view the contents

如何在 Django 中随时随地安全地获取 request

 ·  ☕ 2 分钟

在 Django 中,request 包含了一次请求的全部信息,后端处理逻辑经常需要用到它。比如,在 DRF 框架中希望能随时获取到 request,或者借助它全局传递一些参数。Django 的一些第三方 App 提供了满足这类需求的工具,但它们并不安全可靠:如果 Django 以多线程或协程方式运行,获取 request 时可能会出错(例如拿到的并不是当前请求的 request),这显然是不能接受的。下面是一个安全可靠的实现:它按线程/协程隔离保存当前请求,让你在处理该请求的线程(或协程)中的任意位置都能获取 request 对象。

1. 实现

utils/local.py 文件

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# -*- coding: utf-8 -*-

"""Thread-local/Greenlet-local objects

Thread-local/Greenlet-local objects support the management of
thread-local/greenlet-local data. If you have data that you want
to be local to a thread/greenlet, simply create a
thread-local/greenlet-local object and use its attributes:

  >>> mydata = Local()
  >>> mydata.number = 42
  >>> mydata.number
  42
  >>> hasattr(mydata, 'number')
  True
  >>> hasattr(mydata, 'username')
  False

  Reference :
  from threading import local
"""
try:
    from greenlet import getcurrent as get_ident
except ImportError:
    try:
        from thread import get_ident
    except ImportError:
        from _thread import get_ident

__all__ = ["local", "Local"]


class Localbase(object):
    """Base class giving each instance a private per-context storage table.

    ``__storage__`` maps a context identifier (whatever ``get_ident()``
    returns for the caller) to that context's ``{name: value}`` dict.
    ``__ident_func__`` is the function used to compute that identifier.
    """

    __slots__ = ('__storage__', '__ident_func__')

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are accepted so that subclasses may define an
        # ``__init__`` with parameters, but they must NOT be forwarded to
        # ``object.__new__``: on Python 3 that raises TypeError whenever
        # ``__new__`` is overridden, as it is here.
        self = object.__new__(cls)
        # object.__setattr__ bypasses Local.__setattr__, which rejects
        # writes to these two reserved names.
        object.__setattr__(self, '__storage__', {})
        object.__setattr__(self, '__ident_func__', get_ident)
        return self


class Local(Localbase):
    """Object whose attributes are private to the current thread/greenlet.

    Attribute reads, writes and deletes are routed to the dict stored under
    the current context identifier. Values set in one context are therefore
    invisible from every other context.
    """

    def __iter__(self):
        """Iterate over ``(name, value)`` pairs set in the current context.

        Yields nothing when the current context has not set any attribute.
        """
        ident = self.__ident_func__()
        # Snapshot the items so that setting attributes while iterating
        # cannot invalidate the iterator.
        return iter(list(self.__storage__.get(ident, {}).items()))

    def __release_local__(self):
        """Drop every attribute stored for the current context (no-op if none)."""
        self.__storage__.pop(self.__ident_func__(), None)

    def __getattr__(self, name):
        """Return attribute ``name`` of the current context.

        :raises AttributeError: if ``name`` was not set in this context.
        """
        ident = self.__ident_func__()
        try:
            return self.__storage__[ident][name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        """Set attribute ``name`` for the current context only.

        :raises AttributeError: when trying to overwrite ``__storage__`` or
            ``__ident_func__``.
        """
        if name in ('__storage__', '__ident_func__'):
            raise AttributeError(
                "%r object attribute '%s' is read-only"
                % (self.__class__.__name__, name))

        ident = self.__ident_func__()
        storage = self.__storage__
        try:
            storage[ident][name] = value
        except KeyError:
            # First attribute set in this context: create its dict.
            storage[ident] = {name: value}

    def __delattr__(self, name):
        """Delete attribute ``name`` from the current context.

        When the context's last attribute is removed, its entry is dropped
        from ``__storage__`` entirely so the identifier key is released.

        :raises AttributeError: for the reserved names, or if ``name`` was
            not set in this context.
        """
        if name in ('__storage__', '__ident_func__'):
            raise AttributeError(
                "%r object attribute '%s' is read-only"
                % (self.__class__.__name__, name))

        ident = self.__ident_func__()
        storage = self.__storage__
        try:
            del storage[ident][name]
        except KeyError:
            raise AttributeError(name)
        if not storage[ident]:
            self.__release_local__()


# Module-wide shared instance; each thread/greenlet sees its own attributes.
local = Local()


if __name__ == '__main__':
    # Manual demo (requires gevent): 10 OS threads each spawn 10 greenlets.
    # Every greenlet stores its own value on the shared ``local`` object and
    # prints it back; each printed value is expected to match its greenlet.
    import threading

    def display(value):
        local.id = value
        for _ in range(3):
            # A single parenthesised argument is valid under both Python 2's
            # print statement and Python 3's print() function. The Python-2
            # only ``print a, b`` form is a SyntaxError on Python 3 and would
            # make this whole module unimportable there.
            print("%s %s \n" % (get_ident(), local.id))

    def gree(thread_no):
        import gevent
        greenlets = [gevent.spawn(display, "%s-%s" % (thread_no, i))
                     for i in range(10)]
        gevent.joinall(greenlets)

    threads = [threading.Thread(target=gree, args=(i,)) for i in range(10)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()

utils/request_middlewares.py 文件

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- coding: utf-8 -*-
from django.dispatch import Signal
from django.conf import settings
from utils.local import local


class AccessorSignal(Signal):
    """Signal that only ``RequestProvider`` instances may connect to.

    At most one receiver is ever registered: once any receiver is
    connected, later ``connect`` calls are silently ignored.
    """

    # Dotted "<module>.<class>" of the only receiver type allowed to connect.
    allowed_receiver = 'utils.request_middlewares.RequestProvider'

    def __init__(self, providing_args=None):
        # Forward ``providing_args`` only when the caller supplied it, and by
        # keyword. Django >= 4.0 removed that parameter, and its first
        # positional parameter there is ``use_caching``. Passing it
        # positionally would silently bind to the wrong parameter.
        if providing_args is None:
            Signal.__init__(self)
        else:
            Signal.__init__(self, providing_args=providing_args)

    def connect(self, receiver, sender=None, weak=True, dispatch_uid=None):
        """Connect ``receiver`` if it is an allowed type and none is connected.

        :raises Exception: if ``receiver``'s class is not ``allowed_receiver``.
        """
        receiver_cls = receiver.__class__
        receiver_name = '.'.join(
            [receiver_cls.__module__, receiver_cls.__name__]
        )
        if receiver_name != self.allowed_receiver:
            raise Exception(
                u"%s is not allowed to connect" % receiver_name)
        # Only the first provider is registered; later instances are ignored.
        if not self.receivers:
            Signal.connect(self, receiver, sender, weak, dispatch_uid)


request_accessor = AccessorSignal()


class RequestProvider(object):
    """
    @summary: request event receiver.

    Middleware that stores the request being processed on the
    thread/greenlet-local ``local`` object. Code running in the same context
    can then fetch it, and it also answers ``request_accessor`` sends.
    """

    def __init__(self):
        # Register this instance as the receiver of request_accessor.
        request_accessor.connect(self)

    def process_request(self, request):
        """
        Bind ``request`` to the current context.

        Custom data or processing logic can be attached to ``request`` here.
        """
        local.current_request = request
        return None

    def process_view(self, request, view_func, view_args, view_kwargs):
        """Set ``request.your_args``.

        Precedence: URL kwarg ``your_args``, then POST, then GET, else ``""``.
        """
        your_args = view_kwargs.get("your_args", "")
        if not your_args:
            your_args = (request.POST.get('your_args') or
                         request.GET.get('your_args')) or ""
        request.your_args = your_args

    def process_response(self, request, response):
        """Unbind the current request from this context, then return ``response``."""
        # Clear unconditionally instead of asserting the binding is
        # ``request``. A binding left over from another request must not leak
        # into later code on this thread/greenlet. A failed assertion here
        # would also turn a good response into an error while leaving the
        # stale request bound, and ``assert`` is stripped under ``python -O``.
        if hasattr(local, 'current_request'):
            del local.current_request

        return response

    def __call__(self, **kwargs):
        """Signal receiver: return the request bound to the current context.

        :raises Exception: if no request is bound here (e.g. called from a
            thread spawned by the view).
        """
        if not hasattr(local, 'current_request'):
            raise Exception(
                u"get_request can't be called in a new thread.")
        return local.current_request


def get_request():
    """Return the request bound to the current thread/greenlet.

    :raises Exception: when no request is bound in this context, e.g. the
        caller runs outside request handling.
    """
    if not hasattr(local, 'current_request'):
        raise Exception(u"get_request: current thread hasn't request.")
    return local.current_request


def get_x_request_id():
    """Return the ``HTTP_X_REQUEST_ID`` entry of the current request's META.

    Returns ``''`` when the request has no ``META``, when ``META`` is not a
    dict, or when the key is absent.

    :raises Exception: propagated from ``get_request`` if no request is bound.
    """
    meta = getattr(get_request(), 'META', None)
    if not isinstance(meta, dict):
        return ''
    return meta.get('HTTP_X_REQUEST_ID', '')

2. 使用

1
2
3
4
# settings.py: register RequestProvider so every request is bound to the
# current thread/greenlet.
# NOTE(review): RequestProvider is written as an old-style middleware
# (process_request/process_response hooks, __init__ takes no get_response),
# which is why it is listed in MIDDLEWARE_CLASSES — confirm this matches the
# Django version in use.
MIDDLEWARE_CLASSES = (
    ...
    'utils.request_middlewares.RequestProvider',
    ... )
1
2
3
4
5
from utils.request_middlewares import local

def my_function():
    """Example: read the request bound to the current thread/greenlet."""
    # Raises AttributeError when no request is bound in this context;
    # get_request() from the same module raises a more descriptive error.
    request = local.current_request  # noqa: F841 (example only)

微信公众号
作者
微信公众号