Mkdir700's Note

Mkdir700's Note

Python 泛型 - 如何在实例方法中获取泛型参数T的类型?

2023-02-23

#Python #Python泛型

先上解决方法: https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance

再来简单分析下源码。

talk is cheap, show me the code.

from typing import Dict
Dict[str, int]

Dict只是一个类型,并不是字典类,但是我们可以通过一些方法,拿到其真正意义上的类。

typing 库提供了 get_argsget_origin 函数。

get_args

顾名思义,获取参数。这里的获取参数指的是获取类的泛型参数。

什么意思?

看一看 dict 的代码注解,可以看到 dict 支持泛型的,接受两个泛型参数:_KT_VT
这就是我们可以通过 Dict[str, int] 的方式对字典内键值对的类型进行更加具体标注原因。

所以,Dict[str, int]str, int 就是 Dict 的泛型参数。

通过 get_args 就可以获取到内部的泛型参数,就像这样:


get_origin

获取原始,原始什么?就是获取类型的原始类。

Dict 本身仅仅一个类型,它并不支持去实例化一个字典对象。

如何通过 Dict 而拿到 dict 呢,则就需要使用 get_origin 获取原始类。

就像这样:

值得注意的是:get_args get_origin 仅支持内置的类型

这里再贴一下,官方文档的描述

get_args、get_origin 为泛型类型与特殊类型形式提供了基本的内省功能。
对于 X[Y, Z, ...] 形式的类型对象,这些函数返回 X 与 (Y, Z, ...)。如果 X 是内置对象或 collections class 的泛型别名, 会将其标准化为原始类。如果 X 是包含在其他泛型类型中的联合类型或 Literal(Y, Z, ...) 的顺序会因类型缓存,而与原始参数 [Y, Z, ...] 的顺序不同。对于不支持的对象会相应地返回 None 或 ()

在实例方法中获取原始类及泛型参数的数据类型

接下来看另一种情况,我希望在 Demo 类内部,获取 T 它所对应的真实类,例如:

from typing import Generic, TypeVar, get_args

T = TypeVar("T")
class Demo(Generic[T]):
	def __init__(self):
		print(get_args(self.__class__))

Demo[int]()

看着没什么问题,其实这里会打印空内容。

为什么呢?

self.__class__ 确实获取到了 Demo 类,但是这个 Demo 类并不是最原始的样子。

get_args(self.__class__) 就等同于 get_args(Demo),自然拿不到泛型参数。

我们想要的语句是长这个样子的 get_args(Demo[int]),那么,现在的问题就转换成了如何在 Demo 类内部获取到 Demo[int],我暂且就叫它原始类吧。

先上解决方法,在方法内部调用 self.__orig_class__ 即可获取到原始类。

from typing import Generic, TypeVar, get_args  
  
T = TypeVar("T")  
  
  
class Demo(Generic[T]):  
    def __init__(self):  
        pass  
  
    def test(self):
        c = get_args(self.__orig_class__)[0]
        assert c is int  
  
  
demo = Demo[int]()  
demo.test()

在上面这个示例中,当调用 test 方法时,在该方法内部即可知道泛型参数 T 所对应的类型是什么了。

这里有一个疑问, 为什么`get_args(self.__orig_class__)[0]`写在了`test`方法内,而不是`__init__`初始化方法内。

先说结论:通过以上方法获取泛型参数的类型,只能在该泛型类初始化完成之后才可以使用,即必须在`__init__, __new__`执行后才可调用。

简要分析 Generic 源码

接下来,让我一个人墨迹一会儿,我会简单分析 Generic 的源码,看一看为什么必须在 __init__, __new__ 之后才可以使用。

再多提一句,对于类本身是没有 __orig_class__ 这个属性的,但是为什么我们又可以使用它。

简单点说就是,__orig_class__ 是后来加上的,最初并没有做初始化,如下图 Pycharm 提示了该类不存在 __orig_class__ 属性。

在分析代码之前,可以再看下 Generic[T] 这个写法,它有这么一个中括号的。这个符号在 Python 就是一个语法糖。我们知道,列表对象可以通过 lst[0] 获取到对应下标的元素,字典对象可以通过 d[key] 获取到对应 key 的值,这都是因为列表类和字典类了实现了 __getitem__ 魔术方法。

对于列表和字典,它们都是已经被实例化的对象,而 Generic 是一个类,所以对于类同样支持 [] 语法糖的魔术方法,叫做 __class_getitem__ ,方法名也是比较好记住的,无非就是加了 __class__ 前缀。

现在我们就跳到 Generic 类里面,找到 __class__getitem__ 方法,为了方便浏览,我在以下代码中写注释了。

# 缓存泛型类
@_tp_cache
def __class_getitem__(cls, params):
	# params 很好理解,就是我们传入的泛型参数,`Generic[T]`中 T 就是这个 params
	if not isinstance(params, tuple):
		params = (params,)
	if not params and cls is not Tuple:
		raise TypeError(
			f"Parameter list to {cls.__qualname__}[...] cannot be empty")
	msg = "Parameters to generic types must be types."
	# 类型检查
	params = tuple(_type_check(p, msg) for p in params

	# 只有 Generic 极其子类才可以使用泛型 TypeVar
	if cls in (Generic, Protocol):
		# Generic and Protocol can only be subscripted with unique type variables.
		# 判断所有泛型参数都是 TypeVar 的实例
		if not all(isinstance(p, TypeVar) for p in params):
			raise TypeError(
				f"Parameters to {cls.__name__}[...] must all be type variables")
		if len(set(params)) != len(params):
			raise TypeError(
				f"Parameters to {cls.__name__}[...] must all be unique")
	else:
		# Subscripting a regular Generic subclass.
		_check_generic(cls, params)

	# 重点
	return _GenericAlias(cls, params)

我们关注 __class_getitem__ 的返回结果,返回了 _GenericAlias 的实例,接收两个参数:clsparams。这里的 cls 指的是 Generic 或其子类,params 就是泛型参数。

重点来了,对于 class Demo(Generic[T]) 而言,我们并不是继承至 Generic 而是 _GenericAlias()

_GenericAlias 是什么?看看它的初始化方法,如下:

def __init__(self, origin, params, *, inst=True, special=False, name=None):
	self._inst = inst
	self._special = special
	if special and name is None:
		orig_name = origin.__name__
		name = _normalize_alias.get(orig_name, orig_name)
	self._name = name
	if not isinstance(params, tuple):
		params = (params,)
	# origin 就是 Generic 或继承自它的子类
	self.__origin__ = origin

	self.__args__ = tuple(... if a is _TypingEllipsis else
						  () if a is _TypingEmpty else
						  a for a in params)
	# parmas 转换成了 self.__parameters__
	self.__parameters__ = _collect_type_vars(params)
	self.__slots__ = None  # This is not documented.
	if not name:
		self.__module__ = origin.__module__

其它参数,我们就不了解了,在 Generic 也只传了两个参数,对应这里面的 originparams

class Demo(Generic[T]):
	pass

对于上面代码,换一种写法就是:

class Demo(_GenericAlias(Generic, T)):
	pass

一个类是不是会用到 () 来实例化一个对象,如下:

Demo()

在 Python 中的 () 也是一个语法糖,对应的是 __call__ 方法,所以 Demo() 等同于 Demo.__call__(),本质上就是调用父类的 _GenericAlias(Generic, T).__call__ 方法,所以我们应该去找 _GenericAlias__call__ 方法。

_GenericAlias 没有实现 __call__,而是它继承的父类实现的 _BaseGenericAlias,如下:

def __call__(self, *args, **kwargs):
    if not self._inst:
        raise TypeError(f"Type {self._name} cannot be instantiated; "
                        f"use {self.__origin__.__name__}() instead")
    result = self.__origin__(*args, **kwargs)
    try:
        result.__orig_class__ = self
    except AttributeError:
        pass
    return result

self.__origin__ 就是本例中的 Demo 类,可以看到这里先是进行实例化了,然后再将 self 绑定在了 result 上。注意,这里的 self 指的就是 _GenericAlias 对象。

会到上文讲到的 get_orgs

get_args(self.__orig_class__)[0]

这里获取的 self.__orig_class__ 就是 _GenericAlias 的对象,get_orgs 源码如下:

def get_args(tp):
    """Get type arguments with all substitutions performed.

    For unions, basic simplifications used by Union constructor are performed.
    Examples::
        get_args(Dict[str, int]) == (str, int)
        get_args(int) == ()
        get_args(Union[int, Union[T, int], str][int]) == (int, str)
        get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
        get_args(Callable[[], T][int]) == ([], int)
    """
    if isinstance(tp, _AnnotatedAlias):
        return (tp.__origin__,) + tp.__metadata__
    if isinstance(tp, (_GenericAlias, GenericAlias)):
        res = tp.__args__  # 访问了 __args__
        if _should_unflatten_callable_args(tp, res):
            res = (list(res[:-1]), res[-1])
        return res
    if isinstance(tp, types.UnionType):
        return tp.__args__
    return ()

显而易见,就是访问了 _GenericAlias__args__ 成员。我们再看下 __args__ 是什么?

    def __init__(self, origin, args, *, inst=True, name=None,
                 _paramspec_tvars=False):
        super().__init__(origin, inst=inst, name=name)
        if not isinstance(args, tuple):
            args = (args,)
        self.__args__ = tuple(... if a is _TypingEllipsis else
                              a for a in args)
        self.__parameters__ = _collect_parameters(args)
        self._paramspec_tvars = _paramspec_tvars
        if not name:
            self.__module__ = origin.__module__

__args__ 来自参数传递 args ,让我们回到 Generic.__class_getitem__ 方法,如下:

    @_tp_cache
    def __class_getitem__(cls, params):
        if not isinstance(params, tuple):
            params = (params,)
        # 中间省略大部分内容,都是为了组装 params
        return _GenericAlias(cls, params,
                             _paramspec_tvars=True)

显而易见,就是将 [] 中的泛型参数传了进来,并实例化了 _GenericAlias 对象,并在泛型类实例化时(即调用 __call__ 时)将其绑定在该实例的 __orig_class__ 成员上。

这也解释了为什么只能在非 __init__ 实例方法中访问 __orig_class__,因为泛型类实际上是实例化之后才被绑定的 __orig_class__