Translation (49): Counting the Parameters of a PyTorch Model

Stack Overflow top questions index

If you spot any problems with the translation, please point them out in the comments. Thanks.

Check the total number of parameters in a PyTorch model

  • Fábio Perez asked:

How do I count the total number of parameters in a PyTorch model? Something similar to model.count_params() in Keras.
  • Answers:

    • Fábio Perez - vote: 198

PyTorch doesn't have a function to calculate the total number of parameters as Keras does, but it's possible to sum the number of elements for every parameter group:

    • pytorch_total_params = sum(p.numel() for p in model.parameters())
If you want to calculate only the trainable parameters:

    • pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
Answer inspired by this answer on PyTorch Forums.

Note: I'm answering my own question. If anyone has a better solution, please share it with us.

    • Fábio Perez - vote: 46

To get the parameter count of each layer like Keras, PyTorch has model.named_parameters(), which returns an iterator over both the parameter name and the parameter itself.

Here is an example:

    • from prettytable import PrettyTable

      def count_parameters(model):
          table = PrettyTable(["Modules", "Parameters"])
          total_params = 0
          for name, parameter in model.named_parameters():
              if not parameter.requires_grad:
                  continue
              params = parameter.numel()
              table.add_row([name, params])
              total_params += params
          print(table)
          print(f"Total Trainable Params: {total_params}")
          return total_params

      count_parameters(net)
The output would look something like this:

    • +-------------------+------------+
      |      Modules      | Parameters |
      +-------------------+------------+
      | embeddings.weight |   922866   |
      |    conv1.weight   |  1048576   |
      |     conv1.bias    |    1024    |
      |     bn1.weight    |    1024    |
      |      bn1.bias     |    1024    |
      |    conv2.weight   |  2097152   |
      |     conv2.bias    |    1024    |
      |     bn2.weight    |    1024    |
      |      bn2.bias     |    1024    |
      |    conv3.weight   |  2097152   |
      |     conv3.bias    |    1024    |
      |     bn3.weight    |    1024    |
      |      bn3.bias     |    1024    |
      |    lin1.weight    |  50331648  |
      |     lin1.bias     |    512     |
      |    lin2.weight    |   265728   |
      |     lin2.bias     |    519     |
      +-------------------+------------+
      Total Trainable Params: 56773369
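
    • One caveat worth noting: parameters() and named_parameters() only yield learnable parameters, so buffers such as BatchNorm's running statistics never appear in these counts. A minimal sketch that counts both (the toy block is made up for illustration):

    • import torch.nn as nn

      # a toy block with BatchNorm, invented for illustration
      block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

      n_params = sum(p.numel() for p in block.parameters())  # conv weight/bias + bn weight/bias
      n_buffers = sum(b.numel() for b in block.buffers())    # bn running_mean/var, num_batches_tracked
      print(n_params, n_buffers)  # 240 17
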
    • Thong Nguyen - vote: 12

If you want to avoid double counting shared parameters, you can use torch.Tensor.data_ptr. E.g.:

    • sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
Here's a more verbose implementation that includes an option to filter out non-trainable parameters:

    • import torch

      def numel(m: torch.nn.Module, only_trainable: bool = False):
          """
          Returns the total number of parameters used by m (only counting
          shared parameters once); if only_trainable is True, then only
          includes parameters with requires_grad = True.
          """
          parameters = list(m.parameters())
          if only_trainable:
              parameters = [p for p in parameters if p.requires_grad]
          # keyed by storage address, so each shared tensor is kept only once
          unique = {p.data_ptr(): p for p in parameters}.values()
          return sum(p.numel() for p in unique)
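
    • To see the deduplication matter in practice, here is a minimal sketch that ties two layers to the same storage through two distinct Parameter objects; note that assigning the very same Parameter object (the usual weight-tying idiom) is already skipped by model.parameters() itself:

    • import torch
      import torch.nn as nn

      first = nn.Linear(4, 4, bias=False)
      second = nn.Linear(4, 4, bias=False)
      # wrap the same underlying storage in a second, distinct Parameter object;
      # model.parameters() only skips repeats of the same object, so both are yielded
      second.weight = nn.Parameter(first.weight.data)
      tied = nn.Sequential(first, second)

      naive = sum(p.numel() for p in tied.parameters())
      deduped = sum(dict((p.data_ptr(), p.numel()) for p in tied.parameters()).values())
      print(naive)    # 32: the shared 4x4 weight is counted twice
      print(deduped)  # 16: counted once per storage address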

Copyright notice:
Author: MWHLS
Link: https://panwj.top/3662.html
Source: 无镣之涯
The copyright of this article belongs to the author. Please do not reproduce it without permission.
