0. 引言
前几天分几篇博文精细地讲述了《von Mises-Fisher 分布》, 以及相应的 PyTorch 实现《von Mises-Fisher Distribution (代码解析)》, 其中以 Uniform 分布为例简要介绍了 torch.distributions
包的用法. 本以为已经可以了, 但这两天看到论文 The Power Spherical distribution 的代码, 又被其实现分布的方式所吸引.
Power Spherical 分布与 von Mises Fisher 分布类似, 只不过将后者概率密度函数中的指数函数换成了多项式函数: f p ( x ; μ , κ ) ∝ e x p ( κ μ ⊺ x ) ⇓ f p ( x ; μ , κ ) ∝ ( 1 + μ ⊺ x ) κ \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &\propto exp(\kappa \bm{\mu}^\intercal \bm{x}) \\ &\Downarrow\\ f_p(\bm{x}; \bm{\mu}, \kappa) &\propto (1+\bm{\mu}^\intercal \bm{x})^\kappa \\ \end{aligned} fp(x;μ,κ)fp(x;μ,κ)∝exp(κμ⊺x)⇓∝(1+μ⊺x)κ 采样框架基本一致, 且这么做可以使边缘 t t t 的线性变换 t + 1 2 ∼ B e t a ( p − 1 2 + κ , p − 1 2 ) \frac{t+1}{2} \sim Beta(\frac{p-1}{2}+\kappa, \frac{p-1}{2}) 2t+1∼Beta(2p−1+κ,2p−1), 从而避免了接受-拒绝采样过程.
当然, 按照之前的 VonMisesFisher
的写法, 这个 t
的采样大概是这样:
z = beta.sample(sample_shape)
t = 2 * z - 1
但现在我遇到了这种写法:
class MarginalTDistribution(tds.TransformedDistribution):
arg_constraints = {
'dim': constraints.positive_integer,
'scale': constraints.positive,
}
has_rsample = True
def __init__(self, dim, scale, validate_args=None):
self.dim = dim
self.scale = scale
super().__init__(
tds.Beta( # 用 Beta 分布转换, z 服从 Beta(α+κ,β)
(dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args
),
transforms=tds.AffineTransform(loc=-1, scale=2), # t=2z-1 是想要的边缘分布随机数
)
然后就可以进行对
t
t
t 的采样了.
我们可以看到其基本架构, 本文将详细解析其内部的具体细节, 包括:
1. Distribution
在之前的 <von Mises-Fisher Distribution (代码解析)> 中, 已经通过 Uniform
简单介绍了 Distribution
的用法. 它是实现各种分布的抽象基类. 本文将以解析源码的方式详细介绍.
1.1 参数验证 validate_args
打开源码, 首先映入眼帘的是关于参数验证的代码:
# true if Python was not started with an -O option. See also the assert statement.
_validate_args = __debug__
@staticmethod
def set_default_validate_args(value: bool) -> None:
"""
设置 validation 是否开启.
validation 通常是耗时的, 所以最好在模型 work 后关闭它.
"""
if value not in [True, False]:
raise ValueError
Distribution._validate_args = value
Distribution
有一个类属性叫 _validate_args
, 默认值是 __debug__
(见附录1), 可以通过类静态方法 set_default_validate_args(value: bool)
来修改此值.
构造方法 __init__(...)
中的验证逻辑:
def __init__(self, ..., validate_args: Optional[bool]=None):
...
if validate_args is not None:
self._validate_args = validate_args
也就是说, 你可以在创建 Distribution
实例的时候设置是否进行参数验证. 如果不设置, 则按照类的属性 Distribution._validate_args
来.
if self._validate_args: # validate_args=False 就不用设置 arg_constraints 了
try: # 尝试获取字典 arg_constraints
arg_constraints = self.arg_constraints
except NotImplementedError: # 如果没设置, 则设置为 {}, 抛出警告
arg_constraints = {}
warnings.warn(...)
如果需要验证参数, 那么首先要获取一个叫 arg_constraints
的参数验证字典, 它列出了需要验证哪些参数. 这个抽象类里面并没有给出, 需要用户继承该类时写在子类中. 以 Uniform
为例:
class Uniform(Distribution):
...
arg_constraints = {
"low": constraints.dependent(is_discrete=False, event_dim=0),
"high": constraints.dependent(is_discrete=False, event_dim=0),
}
...
至于 constraints.dependent
是啥, 后面会详细介绍. 值得注意的是, 如果你在创建实例时指定 validate_args=False
, 那么所有关于参数验证的事就都不用管了.
for param, constraint in arg_constraints.items():
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
if param not in self.__dict__ and isinstance(
getattr(type(self), param), lazy_property
):
continue # skip checking lazily-constructed args
value = getattr(self, param) # 从当前对象获取参数 value
valid = constraint.check(value) # 检查参数值
if not valid.all(): # 检查不通过
raise ValueError(...)
这一段就是验证过程了, 包括:
- skip constraints that cannot be checked, 由
constraints.is_dependent(constraint)
判断是否可验证; - skip checking lazily-constructed args, 即参数名不在
self.__dict__
中, 并属于lazy_property
的跳过; - 获得参数, 进行验证;
具体的验证细节将在后面介绍.
1.2 batch_shape
& event_shape
除了 validate_args
参数, __init__(...)
方法中的另外两个参数就是:
def __init__(
self,
batch_shape: torch.Size = torch.Size(),
event_shape: torch.Size = torch.Size(),
):
self._batch_shape = batch_shape
self._event_shape = event_shape
...
这两个参数是啥? 在这个抽象类中, 我们看不到太多信息, 甚至 Uniform
中也只有 batch_shape = self.low.size()
的信息, 大概意思同时进行着一批的均匀分布, 如 low = torch.tensor([0.0, 1.0])
时, batch_shape = torch.Size([2])
, 表示一个二元的均匀分布. 看 MultivariateNormal
, 里面信息量较大:
batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], # [:-2]是去掉了协方差矩阵的维度, 剩下的可能是 batch 的维度
loc.shape[:-1] # [:-1]是去掉了 envent 的维度, 剩下的可能是 batch 的维度
) # broadcast_shapes 意思是进行了广播, 如果 matrix 的 batch_shape 是 [2,1], loc 的 batch_shape 是 [1,2], 那么整个的 batch_shape 是广播后的 [2,2]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) # 之后 covariance_matrix 都被 expand 了
...
event_shape = self.loc.shape[-1:] # 看来就是样本的 shape
从这一段来看, batch_shape
是指创建的实例在进行多少个平行的基本分布, 而 event_shape
是指基本分布的事件(支撑点)维度. 如:
locs = torch.randn(2, 3)
matrixs = torch.randn(2, 3, 3)
covariance_matrixs = torch.bmm(matrixs, matrixs.transpose(1, 2))
normal = distributions.MultivariateNormal(loc=locs, covariance_matrix=covariance_matrixs)
print(normal.batch_shape) # 2
print(normal.event_shape) # 3
print(normal.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[ 1.8972, -0.3961, -0.1530],
[-0.5018, -2.5110, 0.1293]])
batch 的意思还是那个 batch, 不过这里是指分布的 batch, 而不是数据的 batch. 采样时, 得到一批 samples, 对应每个分布.
还有一个 method 和这两个参数有关: expand
, 因为它是一个抽象 method, 基类中并没有实现, 那就直接看 MultivariateNormal
中的:
def expand(self, batch_shape: torch.Size, _instance=None):
"""
Args:
batch_shape (torch.Size): the desired expanded size.
_instance: new instance provided by subclasses that need to override `.expand`.
Returns:
New distribution instance with batch dimensions expanded to `batch_size`.
"""
new = self._get_checked_instance(MultivariateNormal, _instance)
batch_shape = torch.Size(batch_shape)
loc_shape = batch_shape + self.event_shape
cov_shape = batch_shape + self.event_shape + self.event_shape
new.loc = self.loc.expand(loc_shape)
new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
if "covariance_matrix" in self.__dict__:
new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
if "scale_tril" in self.__dict__:
new.scale_tril = self.scale_tril.expand(cov_shape)
if "precision_matrix" in self.__dict__:
new.precision_matrix = self.precision_matrix.expand(cov_shape)
super(MultivariateNormal, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self._validate_args
return new
这个 method 会创建一个新的 instance 或调用的时候用户提供, 并设置 batch_shape
为参数提供的形状, 然后把参数 expand
到新的 batch_shape
. 用法:
mean = torch.randn(3)
matrix = torch.randn(3, 3)
covariance_matrix = torch.mm(matrix, matrix.t())
mvn = MultivariateNormal(mean, covariance_matrix)
bmvn = mvn.expand(torch.Size([2]))
print(bmvn.batch_shape)
print(bmvn.event_shape)
print(bmvn.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[-4.0891, -4.2424, 6.2574],
[ 0.7656, -0.2199, -0.9836]])
1.3 一些属性
包括: m e a n mean mean, m o d e mode mode, s t d std std, v a r i a n c e variance variance, e n t r o p y entropy entropy 等基本属性, 都需要用户在子类中自己实现. 还有一些相关的函数:
- cumulative density/mass function
cdf(value)
; - inverse cumulative density/mass function
icdf(value)
;
这个函数非常有用, Inverse Transform Sampling 中用其进行采样. 从 U ( 0 , 1 ) U(0,1) U(0,1) 中采样一个 u u u, 然后令 x = F − 1 ( u ) x = F^{-1}(u) x=F−1(u) 就是所求随机变量 X X X 的一个采样. - log of the probability density/mass function
log_prob(value)
, 对数概率.
注意, 目前看到的只有 log_prob
, 并没有 prob
, 一些示例要么只算 log_prob
, 要么计算后通过 exp(log_prob)
得到 prob
.
2. constraints.Constraint
前面在1.1参数验证中已经遇到 constraints.dependent(is_discrete=False, event_dim=0)
和 constraint.check(value)
, 但没有讲具体细节. 本节将详细剖析.
2.1 抽象基类 Constraint
先看源码:
class Constraint:
"""
一个 constraint 对象, 表示变量在某区域内有效, 即变量可优化的范围.
"""
is_discrete = False # Default to continuous.
event_dim = 0 # Default to univariate.
def check(self, value):
"""
结果的形状为"sample_shape + batch_shape", 指示 each event 值是否满足此限制.
"""
raise NotImplementedError
这是抽象基类 Constraint
, 比较简单, 只有两个类属性和一个 method check(value)
. is_discrete
表示待验证值是否为离散; 联想前面的 event_shape
, 大概可以知道 event_dim
是指 len(event_shape)
.(不过目前看只是为了验证参数, 还能验证采样的 event?)
2.2 _Dependent()
不被验证
这个基类信息太少, 对我们理解前面的内容毫无用处, 还是直接观察一些子类吧. 从 dependent = _Dependent()
开始, 它是 constraints.py
中定义好的 placeholder(这个倒是可以学一学):
class _Dependent(Constraint): # 看"_", 应该是不希望用户直接创建实例
"""
Placeholder for variables whose support depends on other variables.
These variables obey no simple coordinate-wise constraints.
"""
def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
self._is_discrete = is_discrete
self._event_dim = event_dim
super().__init__()
def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
"""
Support for syntax to customize static attributes::
constraints.dependent(is_discrete=True, event_dim=1)
"""
if is_discrete is NotImplemented: # 未提供就是默认
is_discrete = self._is_discrete
if event_dim is NotImplemented:
event_dim = self._event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
def check(self, x):
raise ValueError("Cannot determine validity of dependent constraint")
闹了半天, 我们并不能看到 constraints.dependent(is_discrete=False, event_dim=0)
有什么卵用, 只知道 “Cannot determine validity of dependent constraint”, 这也呼应了前面的:
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
也就是说, dependent
类型的限制是不会执行参数验证的. 那这个 _Dependent
到底有何用处? 先不管了.
2.3 _IndependentConstraint
重新解释 event_dim
我们看点复杂的, MultivariateNormal.arg_constraints
:
arg_constraints = {
"loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite,
"precision_matrix": constraints.positive_definite,
"scale_tril": constraints.lower_cholesky,
}
这些都是 constraints.py
中定义好的实例, 对于大多情况, 这些预定义好的实例已经够用, 但如果需要, 你也可以自定义. 先看 real_vector
:
independent = _IndependentConstraint
real_vector = independent(real, 1)
class _IndependentConstraint(Constraint):
"""
封装一个 constraint, 通过 aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`,
an event is valid 当且仅当它依赖的所有 entries 是 valid 的.
"""
def __init__(self, base_constraint, reinterpreted_batch_ndims):
self.base_constraint = base_constraint
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
super().__init__()
@property
def event_dim(self):
# real.event_dim 是 0, + real_vector(reinterpreted_batch_ndims=1) = 1
return self.base_constraint.event_dim + self.reinterpreted_batch_ndims
def check(self, value):
result = self.base_constraint.check(value) # 首先要符合 base.check
if result.dim() < self.reinterpreted_batch_ndims:
# 给 batch 留够 dim
expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
raise ValueError(
f"Expected value.dim() >= {expected} but got {value.dim()}"
)
result = result.reshape( # 减掉 event
result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
)
result = result.all(-1) # 减少一个 dim
return result
意思很明了了, real_vector
是依赖于 real
(base_constraint) 的, reinterpreted_batch_ndims=1
是说把原来的 value
重新解释, event_dim
加上 reinterpreted_batch_ndims
, 比如
value = [[1, 2, 3],
[4, 5, 6]]
本来 real
的 event_dim=0
, 验证结果为(sample_shape + batch_shape = (2,2)
):
value = [[True, True, True],
[True, True, True]]
现在重新解释为 event_dim=1
, 验证结果为:
result = result.reshape( # 减掉 event
result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,) # (-1,) 表示新 event 内的所有 entries 展平
)
result = result.all(-1) # 新 event 内的所有 entries 为 True, 则新 event 为 True
================>
value = [True, True]
3. Transform
& _InverseTransform
上一节介绍了 constraints.Constraint
, 明白了在构建 Distribution
实例时进行的参数验证, 以保证用户提供的参数符合要求. 但还留下了一个疑问: Constraint
中的 event_dim
是指 len(event_shape)
, 难道还能验证采样的 event? 再者, check(value)
返回值的形状是 sample_shape + batch_shape
, 进一步说明它是会被用于采样结果检查的. 让我们看一看能否在 Transform
中找到答案.
Transform
& _InverseTransform
是一对互逆的操作, 看一看里面都有什么:
3.1 Attributes
class Transform:
"""
Attributes:
domain (constraints.Constraint):
Transform 的有效输入范围.
codomain (constraints.Constraint):
Transform 的有效输出范围. # 输出是 inverse transform 的输入.
bijective (bool): Transform 是否双射.
即使不是双射, Transforms 也应是弱伪可逆的:
t(t.inv(t(x)) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
sign (int or Tensor): 对于双射单变量 transforms, +1 or -1
取决于 transform 单调增还是单调减.
"""
bijective = False # 默认 False
domain: constraints.Constraint
codomain: constraints.Constraint
class Transform:
"""
可逆变换的抽象基类, with computable log det jacobians.
Caching 对于计算逆复杂或不稳定的变换非常有用.
子类应该实现 one or both of `_call` or `_inverse`.
如果 `bijective=True`, 则必须实现 `log_abs_det_jacobian`.
Args:
cache_size (int): If one, the latest single value is cached.
Only 0 and 1 are supported.
"""
def __init__(self, cache_size=0):
self._cache_size = cache_size
self._inv = None
if cache_size == 0:
pass # default behavior
elif cache_size == 1:
self._cached_x_y = None, None
else:
raise ValueError("cache_size must be 0 or 1")
super().__init__()
def __getstate__(self):
state = self.__dict__.copy()
state["_inv"] = None
return state
@property
def event_dim(self):
if self.domain.event_dim == self.codomain.event_dim: # 当定义域和值域 event_dim 相同时, 才能简略为 event_dim
return self.domain.event_dim
raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")
@property
def inv(self):
"""
Returns the inverse :class:`Transform` of this transform.
This should satisfy ``t.inv.inv is t``.
"""
inv = None
if self._inv is not None:
inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
self._inv = weakref.ref(inv)
return inv
def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
if type(self).__init__ is Transform.__init__:
return type(self)(cache_size=cache_size)
raise NotImplementedError(f"{type(self)}.with_cache is not implemented")
def __call__(self, x):
"""
Computes the transform `x => y`.
"""
if self._cache_size == 0:
return self._call(x)
x_old, y_old = self._cached_x_y
if x is x_old:
return y_old
y = self._call(x)
self._cached_x_y = x, y
return y
def _inv_call(self, y):
"""
Inverts the transform `y => x`.
"""
if self._cache_size == 0:
return self._inverse(y)
x_old, y_old = self._cached_x_y
if y is y_old:
return x_old
x = self._inverse(y)
self._cached_x_y = x, y
return x
def _call(self, x):
"""
Abstract method to compute forward transformation.
"""
raise NotImplementedError
def _inverse(self, y):
"""
Abstract method to compute inverse transformation.
"""
raise NotImplementedError
def log_abs_det_jacobian(self, x, y):
"""
Computes the log det jacobian `log |dy/dx|` given input and output.
"""
raise NotImplementedError
def forward_shape(self, shape):
"""
Infers the shape of the forward computation, given the input shape.
Defaults to preserving shape.
"""
return shape
def inverse_shape(self, shape):
"""
Infers the shapes of the inverse computation, given the output shape.
Defaults to preserving shape.
"""
return shape
附录
1. __debug__
和 assert
(来自 Kimi)
__debug__
是一个内置变量,用于指示 Python 解释器是否处于调试模式。当 Python 以调试模式运行时,__debug__
被设置为 True
;否则,在优化模式下运行时,它被设置为 False
。
__debug__
可以用于条件性地执行调试代码,例如:
if __debug__:
print("Debug mode is on, performing extra checks...")
# 这里可以放一些只在调试模式下运行的代码,比如详细的日志记录
# 或者复杂的验证逻辑
else:
print("Debug mode is off.")
在上面的例子中,如果命令行执行:
python -O myscript.py
##### output #####
Debug mode is off.
------------------------------------------------------
python myscript.py
##### output #####
Debug mode is on, performing extra checks...
assert
语句受 __debug__
影响:
def calculate(a, b):
# 这个 assert 在 __debug__ 为 True 时执行
assert a > 0 and b > 0, "Both inputs must be positive."
# 正常的函数逻辑
return a * b
# 在这里,assert 会检查输入是否为正数
result = calculate(5, 3)
print(result)
# 如果我们改变条件使 assert 失败
# result = calculate(-1, 3) # 这会触发 AssertionError,除非运行时 __debug__ 为 False