|
1 | 1 | module Docs |
2 | 2 | class Pytorch |
3 | 3 | class EntriesFilter < Docs::EntriesFilter |
4 | | - NAME_REPLACEMENTS = { |
5 | | - "Distributed communication package - torch.distributed" => "torch.distributed" |
| 4 | + TYPE_REPLACEMENTS = { |
| 5 | + "torch.Tensor" => "Tensor", |
| 6 | + "torch.nn" => "Neuro Network", |
| 7 | + "Probability distributions - torch.distributions" => "Probability Distributions", |
| 8 | + "torch" => "Torch", |
| 9 | + "Quantization" => "Quantization", |
| 10 | + "torch.optim" => "Optimization", |
| 11 | + "torch.Storage" => "Storage", |
| 12 | + "torch.nn.functional" => "NN Functions", |
| 13 | + "torch.cuda" => "CUDA", |
| 14 | + "Torch Distributed Elastic" => "Distributed Elastic", |
| 15 | + "torch.fx" => "FX", |
| 16 | + "TorchScript" => "Torch Script", |
| 17 | + "torch.onnx" => "ONNX", |
| 18 | + "Distributed communication package - torch.distributed" => "Distributed Communication", |
| 19 | + "Automatic differentiation package - torch.autograd" => "Automatic Differentiation", |
| 20 | + "torch.linalg" => "Linear Algebra", |
| 21 | + "Distributed Checkpoint - torch.distributed.checkpoint" => "Distributed Checkpoint", |
| 22 | + "Distributed RPC Framework" => "Distributed RPC", |
| 23 | + "torch.special" => "SciPy-like Special", |
| 24 | + "torch.package" => "Package", |
| 25 | + "torch.backends" => "Backends", |
| 26 | + "FullyShardedDataParallel" => "Fully Sharded Data Parallel", |
| 27 | + "torch.sparse" => "Sparse Tensors", |
| 28 | + "torch.export" => "Traced Graph Export", |
| 29 | + "torch.fft" => "Discrete Fourier Transforms", |
| 30 | + "torch.utils.data" => "Datasets and Data Loaders", |
| 31 | + "torch.monitor" => "Monitor", |
| 32 | + "Automatic Mixed Precision package - torch.amp" => "Automatic Mixed Precision", |
| 33 | + "torch.utils.tensorboard" => "Tensorboard", |
| 34 | + "torch.profiler" => "Profiler", |
| 35 | + "torch.mps" => "MPS", |
| 36 | + "DDP Communication Hooks" => "DDP Communication Hooks", |
| 37 | + "Benchmark Utils - torch.utils.benchmark" => "Benchmark Utils", |
| 38 | + "torch.nn.init" => "Parameter Initializations", |
| 39 | + "Tensor Parallelism - torch.distributed.tensor.parallel" => "Tensor Parallelism", |
| 40 | + "torch.func" => "JAX-like Function Transforms", |
| 41 | + "Distributed Optimizers" => "Distributed Optimizers", |
| 42 | + "torch.signal" => "SciPy-like Signal", |
| 43 | + "torch.futures" => "Miscellaneous", |
| 44 | + "torch.utils.cpp_extension" => "Miscellaneous", |
| 45 | + "torch.overrides" => "Miscellaneous", |
| 46 | + "Generic Join Context Manager" => "Miscellaneous", |
| 47 | + "torch.hub" => "Miscellaneous", |
| 48 | + "torch.cpu" => "Miscellaneous", |
| 49 | + "torch.random" => "Miscellaneous", |
| 50 | + "torch.compiler" => "Miscellaneous", |
| 51 | + "Pipeline Parallelism" => "Miscellaneous", |
| 52 | + "Named Tensors" => "Miscellaneous", |
| 53 | + "Multiprocessing package - torch.multiprocessing" => "Miscellaneous", |
| 54 | + "torch.utils" => "Miscellaneous", |
| 55 | + "torch.library" => "Miscellaneous", |
| 56 | + "Tensor Attributes" => "Miscellaneous", |
| 57 | + "torch.testing" => "Miscellaneous", |
| 58 | + "torch.nested" => "Miscellaneous", |
| 59 | + "Understanding CUDA Memory Usage" => "Miscellaneous", |
| 60 | + "torch.utils.dlpack" => "Miscellaneous", |
| 61 | + "torch.utils.checkpoint" => "Miscellaneous", |
| 62 | + "torch.__config__" => "Miscellaneous", |
| 63 | + "Type Info" => "Miscellaneous", |
| 64 | + "torch.utils.model_zoo" => "Miscellaneous", |
| 65 | + "torch.utils.mobile_optimizer" => "Miscellaneous", |
| 66 | + "torch._logging" => "Miscellaneous", |
| 67 | + "torch.masked" => "Miscellaneous", |
| 68 | + "torch.utils.bottleneck" => "Miscellaneous" |
6 | 69 | } |
7 | 70 |
|
8 | | - def get_breadcrumbs() |
9 | | - css('.pytorch-breadcrumbs > li').map { |node| node.content.delete_suffix(' >') } |
| 71 | + def get_breadcrumbs |
| 72 | + css('.pytorch-breadcrumbs > li').map { |
| 73 | + |node| node.content.delete_suffix(' >').strip |
| 74 | + }.reject { |item| item.nil? || item.empty? } |
10 | 75 | end |
11 | 76 |
|
12 | 77 | def get_name |
13 | | - # The id of the container `div.section` indicates the page type. |
14 | | - # If the id starts with `module-`, then it's an API reference, |
15 | | - # otherwise it is a note or design doc. |
16 | | - section_id = at_css('.section[id], section[id]')['id'] |
17 | | - if section_id.starts_with? 'module-' |
18 | | - section_id.remove('module-') |
19 | | - else |
20 | | - name = get_breadcrumbs()[1] |
21 | | - NAME_REPLACEMENTS.fetch(name, name) |
22 | | - end |
| 78 | + b = get_breadcrumbs |
| 79 | + b[(b[1] == 'torch' ? 2 : 1)..].join('.') |
23 | 80 | end |
24 | 81 |
|
25 | 82 | def get_type |
26 | | - name |
| 83 | + t = get_breadcrumbs[1] |
| 84 | + TYPE_REPLACEMENTS.fetch(t, t) |
27 | 85 | end |
28 | 86 |
|
29 | 87 | def include_default_entry? |
30 | | - # Only include API references, and ignore notes or design docs |
31 | | - !subpath.start_with? 'generated/' and type.start_with? 'torch' |
| 88 | + # Only include API entries to simplify and unify the list |
| 89 | + return name.start_with?('torch.') |
32 | 90 | end |
33 | 91 |
|
34 | 92 | def additional_entries |
35 | 93 | return [] if root_page? |
36 | 94 |
|
37 | 95 | entries = [] |
38 | | - |
39 | | - css('dt').each do |node| |
40 | | - name = node['id'] |
41 | | - if name == self.name or name == nil |
| 96 | + css('dl').each do |node| |
| 97 | + dt = node.at_css('dt') |
| 98 | + if dt == nil |
| 99 | + next |
| 100 | + end |
| 101 | + id = dt['id'] |
| 102 | + if id == name or id == nil |
42 | 103 | next |
43 | 104 | end |
44 | 105 |
|
45 | | - case node.parent['class'] |
46 | | - when 'method', 'function' |
47 | | - if node.at_css('code').content.starts_with? 'property ' |
48 | | - # this instance method is a property, so treat it as an attribute |
49 | | - entries << [name, node['id']] |
50 | | - else |
51 | | - entries << [name + '()', node['id']] |
52 | | - end |
53 | | - when 'class', 'attribute' |
54 | | - entries << [name, node['id']] |
| 106 | + case node['class'] |
| 107 | + when 'py method', 'py function' |
| 108 | + entries << [id + '()', id] |
| 109 | + when 'py class', 'py attribute', 'py property' |
| 110 | + entries << [id, id] |
| 111 | + when 'footnote brackets', 'field-list simple' |
| 112 | + next |
55 | 113 | end |
56 | 114 | end |
57 | 115 |
|
|
0 commit comments