把上一个阶段的工作收尾之后,终于有时间实现一些平时想实现但比较麻烦就一直懒得弄的功能。事情的起因是这样的:刚刚我正在研究 Human3.6M 数据集的标注,这个数据集的标注是一个类似 JSON 的格式。并且除了一般的 JSON 文件结构外,其中还有一些字段的值是高维 numpy 数组。我试着将其中的两项打印出来,命令行瞬间被一大堆各种各样的数值填满了。尽管 IPython Terminal 的输出有自带的格式化功能,但整个输出的格式还是被巨大的数组打乱了。

实际上我平时经常行地遇到这种问题,一般的解决方法是仿照 mmcv.parallel.collate 中的形式写一个递归打印函数,然后把诸如 np.ndarraytorch.Tensor 以及 mmcv.DataContainer 这种高维数据的 __repr__ 函数映射为一个形式类似 lambda x: f"{x.shape}, {x.dtype}" 的函数。

这种方式虽然能在一定程度上解决高维数据打印的问题,但是由于函数逻辑未经仔细考虑,打印的效果依然不是非常好,甚至在一些 corner case 处会出现十分诡异的效果。为了比较彻底地解决这个问题,我决定基于比较成熟的 Python built-in pprint 来拓展开发,实现这一功能。

我首先考虑的设计方式是引入类似 json.JSONEncoder 的设计方式,也就是在使用 pprint 函数时,添加一个额外的参数,将一个 formatter 类传入,对特定类型对象的序列化方式定义在这个类中,类似这样:

1
2
3
4
5
6
7
class CustomFormatter(DefaultFormatter):
def default(self, obj):
if isinstance(obj, (np.ndarray, torch.Tensor)):
return f"<{obj.__class__.__name__} {obj.dtype} of shape {obj.shape}>"
return DefaultFormatter.default(self, obj)

pprint.pprint(obj, cls=CustomFormatter)

然而观察了一下 pprint 的内部实现之后,我发现要改成这种形式还有一点难度。主要是因为在 pprint 内部,不同对象的序列化并非依赖于一个 formatter,而是通过维护一个字典,并将不同类型对象的序列化方法注册到这个字典中,在遇到相应对象时直接从字典中找到相应方法进行调用即可,类似这种形式:

1
2
3
4
5
6
7
8
9
_dispatch = {}

def _pprint_list(self, object, stream, indent, allowance, context, level):
stream.write('[')
self._format_items(object, stream, indent, allowance + 1,
context, level)
stream.write(']')

_dispatch[list.__repr__] = _pprint_list

如此看来,要想将其拓展到更多类别,只需要把 np.ndarraytorch.Tensor 等类型的 __repr__ 函数也注册到这一 _dispatch 字典中即可。然而此时仍然存在一个问题,这个字典是被所有的 PrettyPrinter 对象共享的,因此如果直接在其中添加键值对,会影响 Python 内置方法的行为。我希望能够在维持 built-in 行为不变的前提下添加这一功能,因此在实际进行实现时,我采用了二级 dispatch 字典,在遇到过长对象时,首先在 _dispatch_override 中查询,然后再在原有的字典中查询,即可实现重载功能。

为了使这一类型具有较好的可拓展性,我也计划开放一个 register 接口,便于用户将更多自定义的类型注册到字典中,以提高更多场景下的可用性。

最后放一个输出的示例(目前还没有适配 mmcv.parallel.DataContainer,仅展示 ndarrayTensor 的效果):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
{'address': {'city': 'New York',
'postalCode': '10021',
'state': 'NY',
'streetAddress': '21 2nd Street'},
'age': 25,
'firstName': 'John',
'lastName': 'Smith',
'nestedDict': {'nestedList': [<Tensor (shape=torch.Size([233]) dtype=torch.float32)>,
<Tensor (shape=torch.Size([114]) dtype=torch.float32)>,
<ndarray (shape=(514,) dtype=float64)>],
'nestedNumpyArray': <ndarray (shape=(123, 456) dtype=float64)>},
'numpyArray': <ndarray (shape=(3, 4, 5, 6) dtype=float64)>,
'phoneNumber': [{'number': '212 555-1234', 'type': 'home'},
{'number': '646 555-4567', 'type': 'fax'}],
'sex': 'male',
'torchTensor': <Tensor (shape=torch.Size([6, 7, 8]) dtype=torch.float32)>}

这一模块目前还在开发和测试阶段,预计会在测试稳定后整合进我的 pypi package 中发布,如果想使用预览版,可以参考这一链接

2023-07-03 更新:前段时间发现了一种更好的实现方式,即基于 prettyprinter 实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import prettyprinter
import torch
import numpy as np

@prettyprinter.register_pretty(torch.Tensor)
def pretty_torchtensor(tensor: torch.Tensor, ctx):
return prettyprinter.pretty_call(ctx,
torch.Tensor,
shape=tensor.shape,
dtype=tensor.dtype,
device=tensor.device)

@prettyprinter.register_pretty(np.ndarray)
def pretty_npndarray(ndarray: np.ndarray, ctx):
return prettyprinter.pretty_call(ctx,
np.ndarray,
shape=ndarray.shape,
dtype=ndarray.dtype)


pprint = prettyprinter.pprint

prettyprinter 的输出示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
{
'numpyArray': numpy.ndarray(
shape=(224, 224, 3),
dtype=dtype('float64')
),
'torchTensor': torch.Tensor(
shape=torch.Size((64, 3, 224, 224)),
dtype=torch.float32,
device=device(type='cpu')
),
'nestedDict': {
'nestedArray': numpy.ndarray(
shape=(224, 224, 3),
dtype=dtype('float64')
),
'nestedTensor': torch.Tensor(
shape=torch.Size((64, 3, 224, 224)),
dtype=torch.float32,
device=device(type='cpu')
)
}
}