A tool to count the FLOPs of a PyTorch model.
pip install thop
(now continuously integrated on GitHub Actions)
OR
pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git
Basic usage
import torch
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
# macs: multiply-accumulate operations; params: number of parameters
macs, params = profile(model, inputs=(input, ))
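profile reports multiply-accumulate operations (MACs) rather than FLOPs. As a rough sanity check on such numbers, the sketch below works through the standard MAC formula for a single 2D convolution in plain Python. The helper name conv2d_macs is hypothetical, bias is ignored, and dilation is assumed to be 1; thop's own counting rule for convolutions may differ in such details. The factor of 2 for FLOPs is a common convention, not something profile applies for you.

```python
def conv2d_macs(c_in, c_out, k, h_in, w_in, stride=1, padding=0, groups=1):
    """Multiply-accumulates for one Conv2d forward pass (square kernel, bias ignored)."""
    # Output spatial size, assuming dilation = 1.
    h_out = (h_in + 2 * padding - k) // stride + 1
    w_out = (w_in + 2 * padding - k) // stride + 1
    # Each output element needs (c_in / groups) * k * k multiply-adds.
    per_output = (c_in // groups) * k * k
    return c_out * h_out * w_out * per_output


# Example: a ResNet-style stem, 3 -> 64 channels, 7x7 kernel,
# stride 2, padding 3, on a 224x224 input (output is 112x112).
macs = conv2d_macs(3, 64, 7, 224, 224, stride=2, padding=3)
flops = 2 * macs  # convention: one MAC = one multiply + one add
print(macs, flops)
```

Counting MACs and doubling them to get FLOPs keeps the arithmetic simple; just be aware that papers differ on which of the two they call "FLOPs".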
Define the rule for a third-party module.
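import torch
import torch.nn as nn
from thop import profile
class YourModule(nn.Module):
    # your definition (here it simply doubles its input, for illustration)
    def forward(self, x):
        return x * 2
def count_your_model(model, x, y):
    # your rule here (x: tuple of inputs, y: output); e.g. one op per output element
    model.total_ops += torch.DoubleTensor([int(y.numel())])
model = nn.Sequential(nn.Conv2d(3, 8, 3), YourModule())
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ),
                        custom_ops={YourModule: count_your_model})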
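Conceptually, custom_ops is a mapping from module type to counting rule. The pure-Python toy model below sketches that idea under stated assumptions: rules are looked up by the layer's exact type, user-supplied rules take precedence over built-in ones, and layers with no rule contribute zero. Layer, ReLU, DEFAULT_RULES and total_ops are all hypothetical names for illustration, not thop internals.

```python
from dataclasses import dataclass


# Hypothetical stand-in for a leaf layer: it only records how many
# elements its output had after a forward pass.
@dataclass
class Layer:
    out_numel: int


class ReLU(Layer):
    pass


class YourModule(Layer):
    pass


def count_relu(layer):
    return 0  # assumption for this sketch: activations are treated as free


def count_your_model(layer):
    return layer.out_numel  # illustrative rule: one op per output element


# Built-in rules of the toy counter, keyed by exact layer type.
DEFAULT_RULES = {ReLU: count_relu}


def total_ops(layers, custom_ops=None):
    # Custom rules are merged on top of the defaults, so they win on conflicts.
    rules = {**DEFAULT_RULES, **(custom_ops or {})}
    total = 0
    for layer in layers:
        rule = rules.get(type(layer))
        if rule is None:
            continue  # no rule for this type: it contributes zero
        total += rule(layer)
    return total


layers = [ReLU(100), YourModule(50)]
print(total_ops(layers))
print(total_ops(layers, custom_ops={YourModule: count_your_model}))
```

Without a rule for YourModule its work silently counts as zero, which is why registering a rule for any module type the counter does not know about matters.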
Improve the output readability
Call thop.clever_format to make the output easier to read.
from thop import clever_format
macs, params = clever_format([macs, params], "%.3f")
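To show what this kind of formatting does, here is a minimal pure-Python sketch that scales a raw count into K/M/G/T units with a printf-style format string. format_count and format_counts are hypothetical names, and the thresholds and suffixes are assumptions; clever_format itself may choose them differently.

```python
def format_count(num, fmt="%.2f"):
    """Scale a raw count to a short human-readable string."""
    # (scale, suffix) pairs in decimal units, largest first.
    for scale, suffix in ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K")):
        if num >= scale:
            return fmt % (num / scale) + suffix
    return fmt % num  # small numbers are printed without a suffix


def format_counts(nums, fmt="%.2f"):
    """Format several counts at once, e.g. [macs, params]."""
    return [format_count(n, fmt) for n in nums]


print(format_count(4_133_742_592, "%.3f"))
print(format_counts([118_013_952, 9_408], "%.1f"))
```

Passing the format string through keeps the precision choice with the caller, mirroring the "%.3f" argument in the snippet above.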
The implementations are adapted from torchvision. The following results can be obtained using benchmark/evaluate_famous_models.py.