0. 引言
前几天分几篇博文精细地讲述了《von Mises-Fisher 分布》, 以及相应的 PyTorch 实现《von Mises-Fisher Distribution (代码解析)》, 其中以 Uniform 分布为例简要介绍了 torch.distributions 包的用法. 本以为已经可以了, 但这两天看到论文 The Power Spherical distribution 的代码, 又被其实现分布的方式所吸引.
Power Spherical 分布与 von Mises Fisher 分布类似, 只不过将后者概率密度函数中的指数函数换成了多项式函数: f p ( x ; μ , κ ) ∝ e x p ( κ μ ⊺ x ) ⇓ f p ( x ; μ , κ ) ∝ ( 1 + μ ⊺ x ) κ \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &\propto exp(\kappa \bm{\mu}^\intercal \bm{x}) \\ &\Downarrow\\ f_p(\bm{x}; \bm{\mu}, \kappa) &\propto (1+\bm{\mu}^\intercal \bm{x})^\kappa \\ \end{aligned} fp(x;μ,κ)fp(x;μ,κ)∝exp(κμ⊺x)⇓∝(1+μ⊺x)κ 采样框架基本一致, 且这么做可以使边缘 t t t 的线性变换 t + 1 2 ∼ B e t a ( p − 1 2 + κ , p − 1 2 ) \frac{t+1}{2} \sim Beta(\frac{p-1}{2}+\kappa, \frac{p-1}{2}) 2t+1∼Beta(2p−1+κ,2p−1), 从而避免了接受-拒绝采样过程.
当然, 按照之前的 VonMisesFisher 的写法, 这个 t 的采样大概是这样:
z = beta.sample(sample_shape)
t = 2 * z - 1
但现在我遇到了这种写法:
class MarginalTDistribution(tds.TransformedDistribution):
arg_constraints = {
'dim': constraints.positive_integer,
'scale': constraints.positive,
}
has_rsample = True
def __init__(self, dim, scale, validate_args=None):
self.dim = dim
self.scale = scale
super().__init__(
tds.Beta( # 用 Beta 分布转换, z 服从 Beta(α+κ,β)
(dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args
),
transforms=tds.AffineTransform(loc=-1, scale=2), # t=2z-1 是想要的边缘分布随机数
)
然后就可以进行对 t t t 的采样了.
架构大概是这样的: 一个基本分布类 distributions.Beta 和一个转换 transforms.AffineTransform, 输入到 TransformedDistribution 的子类 MarginalTDistribution 中, 通过对一个
B
e
t
a
Beta
Beta 的线性转换, 实现边缘分布
t
t
t.

我们可以看到其基本架构, 本文将详细解析其内部的具体细节, 包括:
1. Distribution
在之前的 <von Mises-Fisher Distribution (代码解析)> 中, 已经通过 Uniform 简单介绍了 Distribution 的用法. 它是实现各种分布的抽象基类. 本文将以解析源码的方式详细介绍.
1.1 参数验证 validate_args
打开源码, 首先映入眼帘的是关于参数验证的代码:
# true if Python was not started with an -O option. See also the assert statement.
_validate_args = __debug__
@staticmethod
def set_default_validate_args(value: bool) -> None:
"""
设置 validation 是否开启.
validation 通常是耗时的, 所以最好在模型 work 后关闭它.
"""
if value not in [True, False]:
raise ValueError
Distribution._validate_args = value
Distribution 有一个类属性叫 _validate_args, 默认值是 __debug__(见附录1), 可以通过类静态方法 set_default_validate_args(value: bool) 来修改此值.
构造方法 __init__(...) 中的验证逻辑:
def __init__(self, ..., validate_args: Optional[bool]=None):
...
if validate_args is not None:
self._validate_args = validate_args
也就是说, 你可以在创建 Distribution 实例的时候设置是否进行参数验证. 如果不设置, 则按照类的属性 Distribution._validate_args 来.
if self._validate_args: # validate_args=False 就不用设置 arg_constraints 了
try: # 尝试获取字典 arg_constraints
arg_constraints = self.arg_constraints
except NotImplementedError: # 如果没设置, 则设置为 {}, 抛出警告
arg_constraints = {}
warnings.warn(...)
如果需要验证参数, 那么首先要获取一个叫 arg_constraints 的参数验证字典, 它列出了需要验证哪些参数. 这个抽象类里面并没有给出, 需要用户继承该类时写在子类中. 以 Uniform 为例:
class Uniform(Distribution):
...
arg_constraints = {
"low": constraints.dependent(is_discrete=False, event_dim=0),
"high": constraints.dependent(is_discrete=False, event_dim=0),
}
...
至于 constraints.dependent 是啥, 后面会详细介绍. 值得注意的是, 如果你在创建实例时指定 validate_args=False, 那么所有关于参数验证的事就都不用管了.
for param, constraint in arg_constraints.items():
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
if param not in self.__dict__ and isinstance(
getattr(type(self), param), lazy_property
):
continue # skip checking lazily-constructed args
value = getattr(self, param) # 从当前对象获取参数 value
valid = constraint.check(value) # 检查参数值
if not valid.all(): # 检查不通过
raise ValueError(...)
这一段就是验证过程了, 包括:
- skip constraints that cannot be checked, 由
constraints.is_dependent(constraint)判断是否可验证; - skip checking lazily-constructed args, 即参数名不在
self.__dict__中, 并属于lazy_property的跳过; - 获得参数, 进行验证;
具体的验证细节将在后面介绍.
1.2 batch_shape & event_shape
除了 validate_args 参数, __init__(...) 方法中的另外两个参数就是:
def __init__(
self,
batch_shape: torch.Size = torch.Size(),
event_shape: torch.Size = torch.Size(),
):
self._batch_shape = batch_shape
self._event_shape = event_shape
...
这两个参数是啥? 在这个抽象类中, 我们看不到太多信息, 甚至 Uniform 中也只有 batch_shape = self.low.size() 的信息, 大概意思同时进行着一批的均匀分布, 如 low = torch.tensor([0.0, 1.0]) 时, batch_shape = torch.Size([2]), 表示一个二元的均匀分布. 看 MultivariateNormal, 里面信息量较大:
batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], # [:-2]是去掉了协方差矩阵的维度, 剩下的可能是 batch 的维度
loc.shape[:-1] # [:-1]是去掉了 envent 的维度, 剩下的可能是 batch 的维度
) # broadcast_shapes 意思是进行了广播, 如果 matrix 的 batch_shape 是 [2,1], loc 的 batch_shape 是 [1,2], 那么整个的 batch_shape 是广播后的 [2,2]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) # 之后 covariance_matrix 都被 expand 了
...
event_shape = self.loc.shape[-1:] # 看来就是样本的 shape
从这一段来看, batch_shape 是指创建的实例在进行多少个平行的基本分布, 而 event_shape 是指基本分布的事件(支撑点)维度. 如:
locs = torch.randn(2, 3)
matrixs = torch.randn(2, 3, 3)
covariance_matrixs = torch.bmm(matrixs, matrixs.transpose(1, 2))
normal = distributions.MultivariateNormal(loc=locs, covariance_matrix=covariance_matrixs)
print(normal.batch_shape) # 2
print(normal.event_shape) # 3
print(normal.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[ 1.8972, -0.3961, -0.1530],
[-0.5018, -2.5110, 0.1293]])
batch 的意思还是那个 batch, 不过这里是指分布的 batch, 而不是数据的 batch. 采样时, 得到一批 samples, 对应每个分布.
还有一个 method 和这两个参数有关: expand, 因为它是一个抽象 method, 基类中并没有实现, 那就直接看 MultivariateNormal 中的:
def expand(self, batch_shape: torch.Size, _instance=None):
"""
Args:
batch_shape (torch.Size): the desired expanded size.
_instance: new instance provided by subclasses that need to override `.expand`.
Returns:
New distribution instance with batch dimensions expanded to `batch_size`.
"""
new = self._get_checked_instance(MultivariateNormal, _instance)
batch_shape = torch.Size(batch_shape)
loc_shape = batch_shape + self.event_shape
cov_shape = batch_shape + self.event_shape + self.event_shape
new.loc = self.loc.expand(loc_shape)
new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
if "covariance_matrix" in self.__dict__:
new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
if "scale_tril" in self.__dict__:
new.scale_tril = self.scale_tril.expand(cov_shape)
if "precision_matrix" in self.__dict__:
new.precision_matrix = self.precision_matrix.expand(cov_shape)
super(MultivariateNormal, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self._validate_args
return new
这个 method 会创建一个新的 instance 或调用的时候用户提供, 并设置 batch_shape 为参数提供的形状, 然后把参数 expand 到新的 batch_shape. 用法:
mean = torch.randn(3)
matrix = torch.randn(3, 3)
covariance_matrix = torch.mm(matrix, matrix.t())
mvn = MultivariateNormal(mean, covariance_matrix)
bmvn = mvn.expand(torch.Size([2]))
print(bmvn.batch_shape)
print(bmvn.event_shape)
print(bmvn.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[-4.0891, -4.2424, 6.2574],
[ 0.7656, -0.2199, -0.9836]])
1.3 一些属性
包括: m e a n mean mean, m o d e mode mode, s t d std std, v a r i a n c e variance variance, e n t r o p y entropy entropy 等基本属性, 都需要用户在子类中自己实现. 还有一些相关的函数:
- cumulative density/mass function
cdf(value); - inverse cumulative density/mass function
icdf(value);
这个函数非常有用, Inverse Transform Sampling 中用其进行采样. 从 U ( 0 , 1 ) U(0,1) U(0,1) 中采样一个 u u u, 然后令 x = F − 1 ( u ) x = F^{-1}(u) x=F−1(u) 就是所求随机变量 X X X 的一个采样. - log of the probability density/mass function
log_prob(value), 对数概率.
注意, 目前看到的只有 log_prob, 并没有 prob, 一些示例要么只算 log_prob, 要么计算后通过 exp(log_prob) 得到 prob.
2. constraints.Constraint
前面在1.1参数验证中已经遇到 constraints.dependent(is_discrete=False, event_dim=0) 和 constraint.check(value), 但没有讲具体细节. 本节将详细剖析.
2.1 抽象基类 Constraint
先看源码:
class Constraint:
"""
一个 constraint 对象, 表示变量在某区域内有效, 即变量可优化的范围.
"""
is_discrete = False # Default to continuous.
event_dim = 0 # Default to univariate.
def check(self, value):
"""
结果的形状为"sample_shape + batch_shape", 指示 each event 值是否满足此限制.
"""
raise NotImplementedError
这是抽象基类 Constraint, 比较简单, 只有两个类属性和一个 method check(value). is_discrete 表示待验证值是否为离散; 联想前面的 event_shape, 大概可以知道 event_dim 是指 len(event_shape).(不过目前看只是为了验证参数, 还能验证采样的 event?)
2.2 _Dependent() 不被验证
这个基类信息太少, 对我们理解前面的内容毫无用处, 还是直接观察一些子类吧. 从 dependent = _Dependent() 开始, 它是 constraints.py 中定义好的 placeholder(这个倒是可以学一学):
class _Dependent(Constraint): # 看"_", 应该是不希望用户直接创建实例
"""
Placeholder for variables whose support depends on other variables.
These variables obey no simple coordinate-wise constraints.
"""
def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
self._is_discrete = is_discrete
self._event_dim = event_dim
super().__init__()
def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
"""
Support for syntax to customize static attributes::
constraints.dependent(is_discrete=True, event_dim=1)
"""
if is_discrete is NotImplemented: # 未提供就是默认
is_discrete = self._is_discrete
if event_dim is NotImplemented:
event_dim = self._event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
def check(self, x):
raise ValueError("Cannot determine validity of dependent constraint")
闹了半天, 我们并不能看到 constraints.dependent(is_discrete=False, event_dim=0) 有什么卵用, 只知道 “Cannot determine validity of dependent constraint”, 这也呼应了前面的:
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
也就是说, dependent 类型的限制是不会执行参数验证的. 那这个 _Dependent 到底有何用处? 先不管了.
2.3 _IndependentConstraint 重新解释 event_dim
我们看点复杂的, MultivariateNormal.arg_constraints:
arg_constraints = {
"loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite,
"precision_matrix": constraints.positive_definite,
"scale_tril": constraints.lower_cholesky,
}
这些都是 constraints.py 中定义好的实例, 对于大多情况, 这些预定义好的实例已经够用, 但如果需要, 你也可以自定义. 先看 real_vector:
independent = _IndependentConstraint
real_vector = independent(real, 1)
class _IndependentConstraint(Constraint):
"""
封装一个 constraint, 通过 aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`,
an event is valid 当且仅当它依赖的所有 entries 是 valid 的.
"""
def __init__(self, base_constraint, reinterpreted_batch_ndims):
self.base_constraint = base_constraint
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
super().__init__()
@property
def event_dim(self):
# real.event_dim 是 0, + real_vector(reinterpreted_batch_ndims=1) = 1
return self.base_constraint.event_dim + self.reinterpreted_batch_ndims
def check(self, value):
result = self.base_constraint.check(value) # 首先要符合 base.check
if result.dim() < self.reinterpreted_batch_ndims:
# 给 batch 留够 dim
expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
raise ValueError(
f"Expected value.dim() >= {expected} but got {value.dim()}"
)
result = result.reshape( # 减掉 event
result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
)
result = result.all(-1) # 减少一个 dim
return result
意思很明了了, real_vector 是依赖于 real(base_constraint) 的, reinterpreted_batch_ndims=1 是说把原来 value 的 batch_dim 重新解释, 分出 n 个给 event_dim: 加上 reinterpreted_batch_ndims, 比如
value = [[1, 2, 3],
[4, 5, 6]]
本来 real 的 event_dim=0, 验证结果为(sample_shape + batch_shape = (2,2)):
value = [[True, True, True],
[True, True, True]]
现在重新解释为 event_dim=1, 验证结果为:
result = result.reshape( # 减掉 event
result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,) # (-1,) 表示新 event 内的所有 entries 展平
)
result = result.all(-1) # 新 event 内的所有 entries 为 True, 则新 event 为 True
================>
value = [True, True]
3. Transform & _InverseTransform
上一节介绍了 constraints.Constraint, 明白了在构建 Distribution 实例时进行的参数验证, 以保证用户提供的参数符合要求. 但还留下了一个疑问: Constraint 中的 event_dim 是指 len(event_shape), 难道还能验证采样的 event? 再者, check(value) 返回值的形状是 sample_shape + batch_shape, 进一步说明它是会被用于采样结果检查的. 让我们看一看能否在 Transform 中找到答案.
Transform & _InverseTransform 是一对互逆的操作, 实现从一个分布到另一个分布的转换. 这很有用, 因为 distributions 包已经实现了很多常见分布和转换, 自由组合威力巨大. 本节将详细介绍它是如何实现对分布的转换的.
[注] 从 _InverseTransform 的_ 来看, 是不需要用户了解它的.
3.1 抽象类 Transform 的基本信息
class Transform:
"""
变换的抽象基类, 子类应该实现 one or both of `_call` or `_inverse`.
如果 `bijective=True`, 则必须实现 `log_abs_det_jacobian`.
Args:
cache_size (int): If one, the latest single value is cached.
Only 0 and 1 are supported.
"""
bijective = False # Transform 是否双射, 默认 False
domain: constraints.Constraint # 有效输入范围
codomain: constraints.Constraint # 有效输出范围
def __init__(self, cache_size=0):
self._cache_size = cache_size
self._inv = None
if cache_size == 0:
pass # default behavior
elif cache_size == 1:
self._cached_x_y = None, None
else:
raise ValueError("cache_size must be 0 or 1")
super().__init__()
果然, Transform 中有 Constraint 的, 分别是 domain 和 codomain, 用于其检查输入输出是否符合要求. 此外, 还有 bijective 和 cache_size 这两个信息, 等一下看后面怎么说.
3.2 AffineTransform
抽象类的基本信息不多, 还是要看一个简单的例子: AffineTransform, 线性变换.
class AffineTransform(Transform):
bijective = True
def __init__(self, loc, scale, event_dim=0, cache_size=0):
super().__init__(cache_size=cache_size)
self.loc = loc
self.scale = scale
self._event_dim = event_dim
线性变换是可逆的, 可以看到它的 bijective = True. 参数是
y
=
l
o
c
+
s
c
a
l
e
×
x
y = loc + scale × x
y = loc + scale × x 中的 loc 和 scale; event_dim 则是用于构建 domain 和 codomain:
@constraints.dependent_property(is_discrete=False)
def domain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
@constraints.dependent_property(is_discrete=False)
def codomain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
即, domain 和 codomain 被限制为 event_dim 维向量, 默认是 0, 输入输出皆为标量.
变换过程
def _call(self, x):
"""
Method to compute forward transformation.
"""
return self.loc + self.scale * x
def _inverse(self, y):
"""
Method to compute inverse transformation.
"""
return (y - self.loc) / self.scale
由于是双射, 还要实现:
def log_abs_det_jacobian(self, x, y):
shape = x.shape
scale = self.scale
if isinstance(scale, numbers.Real):
result = torch.full_like(x, math.log(abs(scale)))
else:
result = torch.abs(scale).log()
if self.event_dim:
result_size = result.size()[: -self.event_dim] + (-1,)
result = result.view(result_size).sum(-1)
shape = shape[: -self.event_dim]
return result.expand(shape)
计算结果的形状调整为 x 中除 event_dim 以外的形状, 即 sample_shape + batch_shape. 至于为什么要这么做, 还需要看 TransformedDistribution 中具体的转换流程.
但这里有个问题, 假设 event_dim=1, 输入的 x.shape=(2,3), 而 scale=2.0 和 scale=torch.tensor(2.0) 的计算结果是不一致的:
====================== scale=2.0 ==========================
result = torch.full_like(x, math.log(abs(2.0)))
[[log(2), log(2), log(2)],
[log(2), log(2), log(2)]]
result_size = (2,3)[: -1] + (-1,) = (2,3)
result = [3log(2), 3log(2)].expand([2]) = [3log(2), 3log(2)]
================== scale=tensor(2.0) =======================
result = torch.abs(scale).log() = log(2)
result_size = ()[: -1] + (-1,) = (-1,)
result = log(2).expand([2]) = [log(2), log(2)]
类似的, 只要 scale 是 tensor, 并出现了计算广播, 就会出现这种情况. 不知道会不会造成计算错误, 看了后面的 TransformedDistribution 就能知道. 现在只能暂时不管了.
3.3 TransformedDistribution
3.3.1 基本信息
class TransformedDistribution(Distribution):
"""
Extension of the Distribution class, which applies a sequence of Transforms
to a base distribution.
"""
arg_constraints: Dict[str, constraints.Constraint] = {}
def __init__(self, base_distribution, transforms, validate_args=None):
>>> 单 transfrom 变成 [transfrom], 再检查是否符合 transforms: List[Transform] <<<
它是对 Distribution 的扩展, 对一个 base distribution 实施一连串的 Transforms:
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|
一个简单的例子:
# #################################
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
# #################################
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)
其中 l o g i t ( x ) = l o g x 1 − x logit(x) = log\frac{x}{1-x} logit(x)=log1−xx 是 s i g m o i d sigmoid sigmoid 函数的逆.
下面是 TransformedDistribution 的 __init__(...) 内容(省略了开头将单 Transform 转换为列表以及检查类型的代码):
# >>> Reshape base_distribution according to transforms. >>>
# >>> 获取 base_distribution 的 batch_shape 和 event_shape 以及 event_dim >>>
base_shape = base_distribution.batch_shape + base_distribution.event_shape
base_event_dim = len(base_distribution.event_shape) # 的基本 shape
# <<< 获取 base_distribution 的 batch_shape 和 event_shape 以及 event_dim <<<
# 将 transforms 组合成一个 transform
transform = ComposeTransform(self.transforms)
# 先正向传播 shape, 再反向传播 shape, 一来一回 shape 不一致, 说明途中发生了广播
# 具体例子可为: 线性转换中的 [1,2,3] * [[2],[3]], 输入向量输出矩阵(再反向也是矩阵)
forward_shape = transform.forward_shape(base_shape)
expanded_base_shape = transform.inverse_shape(forward_shape)
if base_shape != expanded_base_shape: # 不一致说明发生了广播 (AffineTransform为例)
base_batch_shape = expanded_base_shape[
: len(expanded_base_shape) - base_event_dim
] # 干脆先把 base_distribution 给 expand 了
# 如 base_shape = batch_shape + event_shape = (,) + (,3) = (3,)
# expanded_base_shape = (2,3), 则 base_batch_shape = (2,)
base_distribution = base_distribution.expand(base_batch_shape) # 结果 base_shape = (2,3)
# transform.domain.event_dim 是指所有 transforms 中最大的 domain.event_dim (这个 domain.event_dim 可能就只是为了检查 dim 是否够用)
reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:
base_distribution = Independent( # 但却实实在在地调整了 base_distribution 的 event_dim
base_distribution, reinterpreted_batch_ndims
) # 参考前面讲的 _IndependentConstraint
self.base_dist = base_distribution
这一部分的主旋律是 Reshape base_distribution according to transforms. 也就是说, self.base_dist 被赋予的是调整过的 base_distribution. 主要包括:
- 调整
batch_shape, bybase_distribution.expand(base_batch_shape), 前面讲过expand; - 调整
event_shape, byIndependent, 这个类似前面讲的_IndependentConstraint, 只不过这里是对Distribution操作;
具体过程看注释. 所以, 使用这种方式建立新的 Distribution 时, 要同时注意 base_distribution 和 transforms 的 event_dim, 这对 log_prob 的计算有影响, 且 base_distribution 的 event_dim 可能被更改.
安排好 self.base_dist 后, 开始计算本 TransformedDistribution 的 batch_shape 和 event_shape.
# Compute shapes.
transform_change_in_event_dim = ( # transform 导致的 event_dim 变化
transform.codomain.event_dim - transform.domain.event_dim
)
event_dim = max(
transform.codomain.event_dim, # the transform is coupled
base_event_dim + transform_change_in_event_dim, # the base dist is coupled
)
assert len(forward_shape) >= event_dim
cut = len(forward_shape) - event_dim # forward_shape 劈开
batch_shape = forward_shape[:cut]
event_shape = forward_shape[cut:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)
3.3.2 采样
def sample(self, sample_shape=torch.Size()):
with torch.no_grad():
x = self.base_dist.sample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x
def rsample(self, sample_shape=torch.Size()):
x = self.base_dist.rsample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x
3.3.3 log_prob 需要 log_abs_det
def log_prob(self, value):
if self._validate_args: # 验证样本的就在此处了
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transforms): # 倒着来
x = transform.inv(y) # 逆变换得到 x, 想计算 `log_prob`, 逆变换就得实现.
event_dim += transform.domain.event_dim - transform.codomain.event_dim
log_prob = log_prob - _sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
y = x
log_prob = log_prob + _sum_rightmost(
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
)
return log_prob
根据 f y ( y ) = f X ( x ) / ∣ d y d x ∣ f_y(y) = f_X(x)/|\frac{dy}{dx}| fy(y)=fX(x)/∣dxdy∣, 比较容易理解 l o g f y ( y ) = l o g f X ( x ) − l o g ∣ d y d x ∣ logf_y(y) = logf_X(x) - log|\frac{dy}{dx}| logfy(y)=logfX(x)−log∣dxdy∣, 那么代码中大概的框架是连续减 l o g ∣ d y d x ∣ log|\frac{dy}{dx}| log∣dxdy∣, 到这没什么问题. 问题就在于为什么要执行:
_sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
把 event_dim - transform.domain.event_dim 个最右侧的维度加起来? 空想不好理解, 必须举个例子, 还是以 AffineTransform 为例:
当 event_dim > transform.domain.event_dim
假设有:
value = [[1, 2, 3],
[4, 5, 6]], event_dim=1
affine = AffineTransform(1, 2, event_dim=0) # transform.domain.event_dim = 0
则计算的 log_abs_det 为:
[[log2, log2, log2],
[log2, log2, log2]]
但按照 event_dim=1, 即基本的 event 单元为 [1, 2, 3] 和 [4, 5, 6], 对应的
l
o
g
∣
d
y
d
x
∣
log|\frac{dy}{dx}|
log∣dxdy∣ 为 [3log2, 3log2], 这就用到了上面说的 _sum_rightmost, 把 event_dim - transform.domain.event_dim 个最右侧的维度加起来.
而在在一连串的 transforms 中, event_dim += transform.domain.event_dim - transform.codomain.event_dim 代表着当前 x 的 event_dim.
event_dim = transform.domain.event_dim 时表示刚刚好, 输入符合 transform.codomain.event_dim 的要求. 有没有可能 event_dim < transform.domain.event_dim? 怕是不能!
4. 实战解析及解惑
在读了 torch.distributions 包的源码之后, 让我们回到论文 The Power Spherical distribution 的代码, 再一看则豁然开朗. 包括边缘变量
t
t
t 和均匀子球的采样
v
\bm{v}
v 的组合操作, 以及组合后的 Householder 变换. 详情请参阅《von Mises-Fisher 分布》.
4.1 边缘变量 t t t 和均匀子球的采样 v \bm{v} v 的组合操作
class _TTransform(tds.Transform):
"""
大概就是 cat(t,v) 的吧, 注意, t 在开头
"""
# 设置为向量还是有必要的, 因为传递过程中的 log_abs_det_jacobian 计算要用到 event_dim
# real.event_dim=0, real_vector.event_dim=1
# 关系到传播中 log_prob 的计算
domain = constraints.real # 输入空间是实数
codomain = constraints.real # 输出空间也是实数
def _call(self, x):
t = x[..., 0].unsqueeze(-1)
v = x[..., 1:]
return torch.cat([t, v * torch.sqrt(torch.clamp(1 - t ** 2, _EPS))], -1)
def _inverse(self, y):
t = y[..., 0].unsqueeze(-1)
v = y[..., 1:]
return torch.cat([t, v / torch.sqrt(torch.clamp(1 - t ** 2, _EPS))], -1)
def log_abs_det_jacobian(self, x, y):
"""
计算变换后的分布的概率密度时有用 fY(y) = fX(x(y))|dx/dy|
:param x: input
:param y: output
:return: the log det jacobian log |dy/dx| given input and output
"""
t = x[..., 0]
# return ((x.shape[-1] - 3) / 2) * torch.log(torch.clamp(1 - t ** 2, _EPS)) # 怎么感觉是 (d-1)/2?
return ((x.shape[-1] - 1) / 2) * torch.log(torch.clamp(1 - t ** 2, _EPS))
_call 和 _inverse 都是对的, 但感觉 log_abs_det_jacobian 有问题. 首先我们从数学上先推导一下这个变换的
d
y
d
x
\frac{dy}{dx}
dxdy. 设
[
t
,
1
−
t
2
v
1
,
⋯
,
1
−
t
2
v
m
−
1
]
=
t
t
r
a
n
s
f
o
r
m
(
[
t
,
v
1
,
⋯
,
v
m
−
1
]
)
[t, \sqrt{1-t^2}v_1, \cdots, \sqrt{1-t^2}v_{m-1}] = ttransform([t, v_1, \cdots, v_{m-1}])
[t,1−t2v1,⋯,1−t2vm−1]=ttransform([t,v1,⋯,vm−1]), 其中
m
m
m 是向量的维度. 那么雅可比矩阵为:
J
=
[
1
0
⋯
0
−
t
v
1
1
−
t
2
1
−
t
2
⋮
0
⋮
⋮
⋱
⋮
−
t
v
m
−
1
1
−
t
2
0
⋯
1
−
t
2
]
J = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ \frac{-tv_1}{\sqrt{1-t^2}} & \sqrt{1-t^2} & \vdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ \frac{-tv_{m-1}}{\sqrt{1-t^2}}& 0 & \cdots & \sqrt{1-t^2} \end{bmatrix}
J=
11−t2−tv1⋮1−t2−tvm−101−t2⋮0⋯⋮⋱⋯00⋮1−t2
则
d
e
t
(
J
)
=
(
1
−
t
2
)
m
−
1
2
det(J) = (1-t^2)^\frac{m-1}{2}
det(J)=(1−t2)2m−1, 那么
l
o
g
∣
d
y
d
x
∣
=
m
−
1
2
l
o
g
(
1
−
t
2
)
log|\frac{dy}{dx}| = \frac{m-1}{2}log(1-t^2)
log∣dxdy∣=2m−1log(1−t2). 我们看一看代码计算结果是什么, 值是刚才计算的值, 形状为 x.shape[:-1]. 那么问题来了, 如果不管 domain.event_dim, 这个计算结果还是对的, 但这里的 domain.event_dim=real.event_dim=0, 而根据 TransformedDistribution 中的 # Compute shapes, 可计算得到 self.event_dim=1, 此时 event_dim - transform.domain.event_dim=1, 你返回这个值后, 计算 log_prob 会再执行:
_sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim
)
从而导致其在 sample_shape[-1] 上相加, 减少一个维度, 这肯定是不对的.
[注] 为什么明明是
l
o
g
∣
d
y
d
x
∣
=
m
−
1
2
l
o
g
(
1
−
t
2
)
log|\frac{dy}{dx}| = \frac{m-1}{2}log(1-t^2)
log∣dxdy∣=2m−1log(1−t2), 代码中却写 (x.shape[-1] - 1) / 2)?
答曰: 不清楚.
4.2 Householder Transform
class _HouseholderRotationTransform(tds.Transform):
"""
完成拼接后, 要进行 HouseholderRotation
"""
domain = constraints.real
codomain = constraints.real
def __init__(self, loc: torch.Tensor):
super().__init__()
e1 = torch.zeros_like(loc) # 继承
e1[..., 0] = 1.0
self.__u = tn_func.normalize(e1 - loc, dim=-1)
def _call(self, x: torch.Tensor):
return x - 2 * (x * self.__u).sum(-1, keepdim=True) * self.__u
def _inverse(self, y: torch.Tensor): # 逆变换是一样的
return y - 2 * (y * self.__u).sum(-1, keepdim=True) * self.__u
def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor):
# h = torch.eye(x.shape[-1], device=x.device) - 2 * torch.outer(self.__u, self.__u)
# torch.log(torch.abs(torch.det(h)))
return 0.0 # 因为 |y|=|x|, 所以 |h|=1; 正交矩阵
具体原理见《householder 变换》. 现在我们关注 log_abs_det_jacobian.
y
=
(
I
−
2
u
u
⊺
)
x
\bm{y} = (I - 2\bm{u}\bm{u}^\intercal) \bm{x}
y=(I−2uu⊺)x, 所以雅可比矩阵为
J
=
I
−
2
u
u
⊺
J = I - 2\bm{u}\bm{u}^\intercal
J=I−2uu⊺,
∣
d
e
t
(
I
−
2
u
u
⊺
)
∣
=
1
|det(I - 2\bm{u}\bm{u}^\intercal)| = 1
∣det(I−2uu⊺)∣=1,
l
o
g
∣
d
e
t
(
J
)
∣
=
0
log|det(J)| = 0
log∣det(J)∣=0.
同样, 值的计算是没有问题的, 只是其返回值 0.0 的 shape 不对, event_dim - transform.domain.event_dim=1, 你返回一个 0.0, TransformedDistribution 中的 log_prob 无法计算.
[注] 正交矩阵的行列式值都为 1 1 1.
4.3 代码的作者自己实现了 log_prob
既然两个转换的 log_abs_det 都有问题, 那为什么还取得了正确的测试结果? 答案是: 作者自己实现了 log_prob 的计算, 而并未使用 TransformedDistribution 中的 log_prob.
def log_prob(self, value):
return self.log_normalizer() + self.scale * torch.log1p(
(self.loc * value).sum(-1)
)
假设我们注释掉这个作者实现的 log_prob, 转而使用 TransformedDistribution 中的 log_prob, 看看会有正确的结果不:
loc = torch.tensor([0.0, 1.0], requires_grad=True)
scale = torch.tensor(4.0, requires_grad=True)
dist = PowerSpherical(loc, scale)
step_size = 0.001
x = torch.arange(0, 2 * math.pi, step_size)
pt = torch.stack((torch.cos(x), torch.sin(x))).t()
y = torch.exp(dist.log_prob(pt)).detach()
print('integal:', y.sum() * step_size)
###################### output #######################
integal: tensor(inf)
竟然没有报错, 只是输出了一个 tensor(inf). 经过检查, 发现忽略了各 Distribution 的 event_shape, 作者竟然都设置成了默认, 即 torch.Size([]), 且各 Transform 的 forward_shape 和 inverse_shape 都保持默认, 那么送进 TransformedDistribution 后, base_distribution 不会被 expand, 且 event_shape=torch.Size([]), 进而 event_dim - transform.domain.event_dim = 0 -0 = 0, _sum_rightmost 从来都不会计算.
即使 _HouseholderRotationTransform.log_abs_det_jacobian 返回值为 0, 经过:
log_prob = log_prob - _sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
链条的广播计算, 也不是问题. 可以正确计算了? 那为什么是 inf? 问题就在于边缘变量 t 的 log_prb 的计算, 当
t
=
1
t=1
t=1 (
x
=
μ
\bm{x}=\bm{\mu}
x=μ) 时, Beta 分布的 log_prob(1)=log(0), 从而出现无穷的情况, 而这本来能乘以该处的均匀子球概率密度
1
0
\frac{1}{0}
01 避免的.
4.4 心得与教训
可以看到, 这种 BaseDistribution + Transform 实现复杂分布的方式提供方便的同时, 也会存在很多问题. 这里总结几点心得与教训:
- 当基础分布和变换都已经存在时, 使用此架构可以快速搭建新的概率分布, 不必考虑太多细节, 比如
log_pdf,cdf以及一些基础属性; - 当需要实现一个复杂分布时, 尽可能拆解成简单分布和变换, 注意要朝着已存在变换的方向分解, 哪怕拆出来的基础分布不在 PyTorch 的包内, 只要能使分布更简单, 就能达到简化的目的;
-
尽量不要自己写
Transform, 因为要实现的东西有点多,_inverse和log_abs_det_jacobian都比较麻烦, 有那个功夫, 也已经把复杂分布的pdf算出来了; -
尽量直接继承
Distribution实现自己的分布, 为不是使用转换链. 因为转换意味着要计算log_abs_det_jacobian, 然后沿着转换链累加, 这比直接计算目标分布的log_prob增加了许多计算, 且承担不稳定的风险. 比如此例子中的接连三个变换, 计算复杂不说, 还拆开了 x = μ \bm{x} = \bm{\mu} x=μ 时的 f ( t = 1 ) = 0 f(t=1) = 0 f(t=1)=0 和 f ( 1 − t 2 v ) = 1 S p − 2 = 1 0 f(\sqrt{1-t^2}\bm{v}) = \frac{1}{S_{p-2}} = \frac{1}{0} f(1−t2v)=Sp−21=01, 加之为了避免 0 0 0 作为分母而加上的 e p s eps eps 偏移, 使得计算失败. - 如果只是为了采样, 则不必考虑
_inverse和log_abs_det_jacobian, 因为它们只出现在log_prob中.
4.5 关于 batch_shape, event_shape 和 Constraint.event_dim.
如果迫不得已需要使用转换链, 我还是建议尽量实事求是地把这三个参数正确地写上, 而不是全部保持为 0. 一则不符合代码逻辑, 二来不定在哪就出错了. 包的作者既然这么设置, 肯定有他的道理.
附录
1. __debug__ 和 assert (来自 Kimi)
__debug__ 是一个内置变量,用于指示 Python 解释器是否处于调试模式。当 Python 以调试模式运行时,__debug__ 被设置为 True;否则,在优化模式下运行时,它被设置为 False。
__debug__ 可以用于条件性地执行调试代码,例如:
if __debug__:
print("Debug mode is on, performing extra checks...")
# 这里可以放一些只在调试模式下运行的代码,比如详细的日志记录
# 或者复杂的验证逻辑
else:
print("Debug mode is off.")
在上面的例子中,如果命令行执行:
python -O myscript.py
##### output #####
Debug mode is off.
------------------------------------------------------
python myscript.py
##### output #####
Debug mode is on, performing extra checks...
assert 语句受 __debug__ 影响:
def calculate(a, b):
# 这个 assert 在 __debug__ 为 True 时执行
assert a > 0 and b > 0, "Both inputs must be positive."
# 正常的函数逻辑
return a * b
# 在这里,assert 会检查输入是否为正数
result = calculate(5, 3)
print(result)
# 如果我们改变条件使 assert 失败
# result = calculate(-1, 3) # 这会触发 AssertionError,除非运行时 __debug__ 为 False
2. t t t 的概率密度函数推导
直接将 t = μ ⊺ x t = \bm{\mu}^\intercal\bm{x} t=μ⊺x 代入 f p ( x ; μ , κ ) f_p(\bm{x}; \bm{\mu}, \kappa) fp(x;μ,κ), 得: f p ( x ; μ , κ ) = C p ( κ ) ( 1 + μ ⊺ x ) κ = C p ( κ ) ( 1 + t ) κ t ∈ [ − 1 , 1 ] = C p ( κ ) ( 1 + c o s θ ) κ θ ∈ [ 0 , π ] \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &= C_p(\kappa) (1 + \bm{\mu}^\intercal\bm{x})^\kappa & \\ &= C_p(\kappa) (1 + t)^\kappa & t \in [-1, 1] \\ &= C_p(\kappa) (1 + cos\theta)^\kappa & \theta \in [0, \pi] \end{aligned} fp(x;μ,κ)=Cp(κ)(1+μ⊺x)κ=Cp(κ)(1+t)κ=Cp(κ)(1+cosθ)κt∈[−1,1]θ∈[0,π] 注意这是 x \bm{x} x 一点处的概率密度. 沿着 t t t 处的切子球求积分, 以得到 t t t 或 θ \theta θ 处的整个概率密度: ∫ 切子球 f p ( x ; μ , κ ) d s = ∫ 切子球 C p ( κ ) ( 1 + μ ⊺ x ) κ d s = C p ( κ ) ( 1 + t ) κ 2 π p − 1 2 Γ ( p − 1 2 ) ( 1 − t 2 ) p − 2 2 S p − 2 的表面积 ∝ r p − 2 = C p ( κ ) ( 1 + c o s θ ) κ 2 π p − 1 2 Γ ( p − 1 2 ) s i n p − 2 θ \begin{aligned} \int_{切子球} f_p(\bm{x}; \bm{\mu}, \kappa) ds &= \int_{切子球} C_p(\kappa) (1 + \bm{\mu}^\intercal\bm{x})^\kappa ds & \\ &= C_p(\kappa) (1 + t)^\kappa\frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})}(1-t^2)^\frac{p-2}{2} & S^{p-2} 的表面积 \propto r^{p-2} \\ &= C_p(\kappa) (1 + cos\theta)^\kappa \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} sin^{p-2}\theta & \end{aligned} ∫切子球fp(x;μ,κ)ds=∫切子球Cp(κ)(1+μ⊺x)κds=Cp(κ)(1+t)κΓ(2p−1)2π2p−1(1−t2)2p−2=Cp(κ)(1+cosθ)κΓ(2p−1)2π2p−1sinp−2θSp−2的表面积∝rp−2 根据 n-sphere - Wikipedia, 切子球 S p − 2 S^{p-2} Sp−2 的表面积 S p − 2 = 2 π p − 1 2 Γ ( p − 1 2 ) r p − 2 S_{p-2} = \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} r^{p-2} Sp−2=Γ(2p−1)2π2p−1rp−2, 再沿 t t t 或 θ \theta θ 积分: ∫ 0 π C p ( κ ) ( 1 + c o s θ ) κ 2 π p − 1 2 Γ ( p − 1 2 ) s i n p − 2 θ d θ = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ 1 − 1 ( 1 + t ) κ ( 1 − t 2 ) p − 2 2 ( − 1 1 − t 2 d t ) ∵ c o s 0 = 1 , c o s π = − 1 = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) κ ( 1 − t 2 ) p − 3 2 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) κ [ ( 1 + t ) ( 1 − t ) ] p − 3 2 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ − 1 1 ( 1 + t ) p − 1 2 + κ − 1 ( 1 − t ) p − 1 2 − 1 d t = C p ( κ ) 2 π p − 1 2 Γ ( p − 1 2 ) ∫ 0 1 ( 2 z ) p − 1 2 + κ − 1 ( 2 ( 1 − z ) ) p − 1 2 − 1 2 d z t = 2 z − 1 = C p ( κ ) 2 p + κ − 1 π p − 1 2 Γ ( p − 1 2 ) ∫ 0 1 z p − 1 2 + κ − 1 ( 1 − z ) p − 1 2 − 1 d z \begin{aligned} & \int_0^\pi C_p(\kappa) (1 + cos\theta)^\kappa \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} sin^{p-2}\theta d\theta \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{1}^{-1} (1 + t)^\kappa (1-t^2)^\frac{p-2}{2} (\frac{-1}{\sqrt{1-t^2}} dt) & \because cos0=1,~ cos\pi=-1 \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^\kappa (1-t^2)^{\frac{p-3}{2}} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^\kappa [(1+t)(1-t)]^{\frac{p-3}{2}} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{-1}^{1} (1 + t)^{\frac{p-1}{2}+\kappa-1} (1-t)^{\frac{p-1}{2}-1} dt \\ =& C_p(\kappa) \frac{2\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{0}^{1} (2z)^{\frac{p-1}{2}+\kappa-1} (2(1-z))^{\frac{p-1}{2}-1} 2dz & t=2z-1 \\ =& C_p(\kappa) \frac{2^{p+\kappa-1}\pi^{\frac{p-1}{2}}}{\Gamma(\frac{p-1}{2})} \int_{0}^{1} z^{\frac{p-1}{2}+\kappa-1} (1-z)^{\frac{p-1}{2}-1} dz \end{aligned} ======∫0πCp(κ)(1+cosθ)κΓ(2p−1)2π2p−1sinp−2θdθCp(κ)Γ(2p−1)2π2p−1∫1−1(1+t)κ(1−t2)2p−2(1−t2−1dt)Cp(κ)Γ(2p−1)2π2p−1∫−11(1+t)κ(1−t2)2p−3dtCp(κ)Γ(2p−1)2π2p−1∫−11(1+t)κ[(1+t)(1−t)]2p−3dtCp(κ)Γ(2p−1)2π2p−1∫−11(1+t)2p−1+κ−1(1−t)2p−1−1dtCp(κ)Γ(2p−1)2π2p−1∫01(2z)2p−1+κ−1(2(1−z))2p−1−12dzCp(κ)Γ(2p−1)2p+κ−1π2p−1∫01z2p−1+κ−1(1−z)2p−1−1dz∵cos0=1, cosπ=−1t=2z−1 那么, 将 α = p − 1 2 + κ β = p − 1 2 C p ( κ ) 2 α + β π β Γ ( β ) = Γ ( α + β ) Γ ( α ) Γ ( β ) C p ( κ ) = Γ ( α + β ) 2 α + β π β Γ ( α ) \begin{aligned} \alpha =& \frac{p-1}{2}+\kappa \\ \beta =& \frac{p-1}{2} \\ C_p(\kappa) \frac{2^{\alpha+\beta}\pi^{\beta}}{\Gamma(\beta)} =& \frac{\Gamma(\alpha+\beta)}{\Gamma(\alpha)\Gamma(\beta)} \\ C_p(\kappa) =& \frac{\Gamma(\alpha+\beta)}{2^{\alpha+\beta}\pi^{\beta}\Gamma(\alpha)} \end{aligned} α=β=Cp(κ)Γ(β)2α+βπβ=Cp(κ)=2p−1+κ2p−1Γ(α)Γ(β)Γ(α+β)2α+βπβΓ(α)Γ(α+β)