Skip to content

Commit

Permalink
updated _LPVV to fix gh issue
Browse files Browse the repository at this point in the history
  • Loading branch information
oguiza committed Feb 11, 2024
1 parent bb69bef commit fa66212
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 94 deletions.
32 changes: 16 additions & 16 deletions nbs/022_tslearner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.523830</td>\n",
" <td>0.266667</td>\n",
" <td>1.407878</td>\n",
" <td>0.300000</td>\n",
" <td>1.464314</td>\n",
" <td>0.233333</td>\n",
" <td>1.400173</td>\n",
" <td>0.166667</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" </tbody>\n",
Expand Down Expand Up @@ -339,8 +339,8 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.563760</td>\n",
" <td>0.166667</td>\n",
" <td>1.438072</td>\n",
" <td>0.200000</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" </tbody>\n",
Expand Down Expand Up @@ -600,10 +600,10 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>209.704529</td>\n",
" <td>13.806342</td>\n",
" <td>207.336456</td>\n",
" <td>13.982669</td>\n",
" <td>221.817291</td>\n",
" <td>14.270400</td>\n",
" <td>209.151230</td>\n",
" <td>14.046944</td>\n",
" <td>00:01</td>\n",
" </tr>\n",
" </tbody>\n",
Expand Down Expand Up @@ -860,10 +860,10 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>4226.109375</td>\n",
" <td>49.230492</td>\n",
" <td>8007.046387</td>\n",
" <td>74.881180</td>\n",
" <td>4114.624023</td>\n",
" <td>48.891418</td>\n",
" <td>7991.095703</td>\n",
" <td>74.791130</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" </tbody>\n",
Expand Down Expand Up @@ -932,9 +932,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/nacho/notebooks/tsai/nbs/022_tslearner.ipynb saved at 2024-02-11 00:40:14\n",
"/Users/nacho/notebooks/tsai/nbs/022_tslearner.ipynb saved at 2024-02-11 10:55:07\n",
"Correct notebook to script conversion! 😃\n",
"Sunday 11/02/24 00:40:17 CET\n"
"Sunday 11/02/24 10:55:10 CET\n"
]
},
{
Expand Down
139 changes: 73 additions & 66 deletions nbs/076_models.MultiRocketPlus.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,22 @@
"source": [
"#| export\n",
"def _LPVV(o, dim=2):\n",
" \"Longest stretch of positive values (-1, 1)\"\n",
" \"Longest stretch of positive values along a dimension(-1, 1)\"\n",
"\n",
" seq_len = o.shape[dim]\n",
"\n",
" # Convert tensor to binary format (1 for positive values, 0 for non-positive values)\n",
" binary_tensor = (o > 0).float()\n",
"\n",
" # Find the changes in the binary tensor\n",
" diff = torch.cat([torch.ones_like(binary_tensor.narrow(dim, 0, 1)),\n",
" binary_tensor.narrow(dim, 1, binary_tensor.shape[dim]-1) - binary_tensor.narrow(dim, 0, binary_tensor.shape[dim]-1)], dim=dim)\n",
" binary_tensor.narrow(dim, 1, seq_len-1) - binary_tensor.narrow(dim, 0, seq_len-1)], dim=dim)\n",
"\n",
" # Create groups of positive values\n",
" groups = (diff > 0).cumsum(dim)\n",
"\n",
" # Count the number of values in each group\n",
" counts = torch.zeros_like(binary_tensor).scatter_add_(dim, groups * binary_tensor.long(), binary_tensor)\n",
" # Ensure groups are within valid index bounds\n",
" groups = groups * binary_tensor.long()\n",
" valid_groups = groups.where(groups < binary_tensor.size(dim), torch.tensor(0, device=groups.device))\n",
"\n",
" counts = torch.zeros_like(binary_tensor).scatter_add_(dim, valid_groups, binary_tensor)\n",
"\n",
" # The longest stretch of positive values is the maximum count\n",
" longest_stretch = counts.max(dim)[0]\n",
"\n",
" return torch.nan_to_num(2 * (longest_stretch / seq_len) - 1)\n",
Expand Down Expand Up @@ -118,6 +116,15 @@
" return (o_pos).float().mean(dim) * 2 - 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tsai.imports import default_device"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -127,54 +134,54 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[[-0.0924, 0.0842, 0.5685, 0.3900],\n",
" [ 0.2364, 0.3018, -0.0449, 0.2081],\n",
" [ 0.6782, 0.1842, 0.6873, -0.0590],\n",
" [ 0.1263, 0.2636, 0.3605, -0.0281],\n",
" [ 0.5618, 0.3535, 0.5403, -0.1791]],\n",
"tensor([[[[ 0.5644, -0.0509, -0.0390, 0.4091],\n",
" [ 0.0517, -0.1471, 0.6458, 0.5593],\n",
" [ 0.4516, -0.0821, 0.1271, 0.0592],\n",
" [ 0.4151, 0.4376, 0.0763, 0.3780],\n",
" [ 0.2653, -0.1817, 0.0156, 0.4993]],\n",
"\n",
" [[ 0.2201, 0.1868, 0.1791, -0.1343],\n",
" [ 0.3556, -0.1194, -0.2201, 0.4859],\n",
" [ 0.1115, 0.6232, 0.4436, 0.3880],\n",
" [ 0.6350, 0.1362, 0.5869, -0.1968],\n",
" [ 0.0876, 0.4583, 0.0266, 0.3174]],\n",
" [[-0.0779, 0.0858, 0.1982, 0.3224],\n",
" [ 0.1130, 0.0714, -0.1779, 0.5360],\n",
" [-0.1848, -0.2270, -0.0925, -0.1217],\n",
" [ 0.2820, -0.0205, -0.2777, 0.3755],\n",
" [-0.2490, 0.2613, 0.4237, 0.4534]],\n",
"\n",
" [[-0.1895, 0.1921, 0.2437, -0.1854],\n",
" [-0.1534, -0.2986, 0.2977, 0.3019],\n",
" [ 0.4613, 0.4243, 0.0115, 0.2684],\n",
" [-0.0923, 0.2066, 0.4980, 0.6450],\n",
" [-0.0348, -0.0297, 0.5451, 0.1900]]],\n",
" [[-0.0162, 0.6368, 0.0016, 0.1467],\n",
" [ 0.6035, -0.1365, 0.6930, 0.6943],\n",
" [ 0.2790, 0.3818, -0.0731, 0.0167],\n",
" [ 0.6442, 0.3443, 0.4829, -0.0944],\n",
" [ 0.2932, 0.6952, 0.5541, 0.5946]]],\n",
"\n",
"\n",
" [[[ 0.0524, 0.3093, -0.1079, 0.6815],\n",
" [-0.0642, -0.1675, -0.0548, -0.2654],\n",
" [ 0.3172, 0.2939, -0.2412, -0.0502],\n",
" [ 0.1145, -0.0048, 0.0118, 0.1329],\n",
" [ 0.1715, 0.0915, -0.0179, 0.1825]],\n",
" [[[ 0.6757, 0.5740, 0.3071, 0.4400],\n",
" [-0.2344, -0.1056, 0.4773, 0.2432],\n",
" [ 0.2595, -0.1528, -0.0866, 0.6201],\n",
" [ 0.0657, 0.1220, 0.4849, 0.4254],\n",
" [ 0.3399, -0.1609, 0.3465, 0.2389]],\n",
"\n",
" [[ 0.3505, 0.1599, 0.4867, 0.0462],\n",
" [-0.1878, 0.2045, 0.0392, -0.0331],\n",
" [-0.2096, 0.6557, 0.6754, 0.4057],\n",
" [ 0.6317, 0.1402, -0.2868, 0.2319],\n",
" [-0.1239, -0.2330, 0.4047, 0.0263]],\n",
" [[-0.0765, 0.0516, 0.0028, 0.4381],\n",
" [ 0.5212, -0.2781, -0.0896, -0.0301],\n",
" [ 0.6857, 0.3583, 0.5869, 0.3418],\n",
" [ 0.3002, 0.5135, 0.6011, 0.6499],\n",
" [-0.2807, -0.2888, 0.3965, 0.6585]],\n",
"\n",
" [[ 0.3576, 0.6521, 0.6509, 0.0302],\n",
" [ 0.6389, 0.3282, 0.6566, 0.3341],\n",
" [-0.0629, -0.1169, 0.0781, 0.2252],\n",
" [ 0.4982, 0.2185, 0.4328, 0.5555],\n",
" [ 0.3052, 0.0192, 0.6695, -0.2008]]]])\n",
"tensor([[[ 0.6000, 1.0000, 0.2000, -0.2000],\n",
" [ 1.0000, 0.2000, 0.2000, -0.2000],\n",
" [-0.6000, -0.2000, 1.0000, 0.6000]],\n",
" [[-0.1368, 0.6677, 0.1439, 0.1434],\n",
" [-0.1820, 0.1041, -0.1211, 0.6103],\n",
" [ 0.5808, 0.4588, 0.4572, 0.3713],\n",
" [ 0.2389, -0.1392, 0.1371, -0.1570],\n",
" [ 0.2840, 0.1214, -0.0059, 0.5064]]]], device='mps:0')\n",
"tensor([[[ 1.0000, -0.6000, 0.6000, 1.0000],\n",
" [-0.6000, -0.2000, -0.6000, -0.2000],\n",
" [ 0.6000, 0.2000, -0.2000, 0.2000]],\n",
"\n",
" [[ 0.2000, -0.6000, -0.6000, -0.2000],\n",
" [-0.6000, 0.6000, 0.2000, 0.2000],\n",
" [-0.2000, -0.2000, 1.0000, 0.6000]]])\n"
" [[ 0.2000, -0.6000, -0.2000, 1.0000],\n",
" [ 0.2000, -0.2000, 0.2000, 0.2000],\n",
" [ 0.2000, 0.2000, -0.2000, 0.2000]]], device='mps:0')\n"
]
}
],
"source": [
"o = torch.rand(2, 3, 5, 4) - .3\n",
"o = torch.rand(2, 3, 5, 4).to(default_device()) - .3\n",
"print(o)\n",
"\n",
"output = _LPVV(o, dim=2)\n",
Expand All @@ -190,13 +197,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[0.4007, 0.2374, 0.5392, 0.2991],\n",
" [0.2820, 0.3511, 0.3091, 0.3971],\n",
" [0.4613, 0.2744, 0.3192, 0.3513]],\n",
"tensor([[[0.3496, 0.4376, 0.2162, 0.3810],\n",
" [0.1975, 0.1395, 0.3109, 0.4218],\n",
" [0.4550, 0.5145, 0.4329, 0.3631]],\n",
"\n",
" [[0.1639, 0.2316, 0.0118, 0.3323],\n",
" [0.4911, 0.2901, 0.4015, 0.1775],\n",
" [0.4500, 0.3045, 0.4976, 0.2862]]])\n"
" [[0.3352, 0.3480, 0.4040, 0.3935],\n",
" [0.5023, 0.3078, 0.3968, 0.5221],\n",
" [0.3679, 0.3380, 0.2460, 0.4079]]], device='mps:0')\n"
]
}
],
Expand All @@ -214,13 +221,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[ 0.8910, 1.0000, 0.9592, 0.3842],\n",
" [ 1.0000, 0.8432, 0.6978, 0.5650],\n",
" [-0.0094, 0.4297, 1.0000, 0.7668]],\n",
"tensor([[[ 1.0000, -0.0270, 0.9138, 1.0000],\n",
" [-0.1286, 0.2568, 0.0630, 0.8654],\n",
" [ 0.9823, 0.8756, 0.9190, 0.8779]],\n",
"\n",
" [[ 0.8217, 0.6025, -0.9458, 0.5190],\n",
" [ 0.3065, 0.6655, 0.6970, 0.9109],\n",
" [ 0.9325, 0.8248, 1.0000, 0.7015]]])\n"
" [[ 0.7024, 0.2482, 0.8983, 1.0000],\n",
" [ 0.6168, 0.2392, 0.8931, 0.9715],\n",
" [ 0.5517, 0.8133, 0.7065, 0.8244]]], device='mps:0')\n"
]
}
],
Expand All @@ -238,13 +245,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[-0.3959, -0.5251, -0.1553, -0.8672],\n",
" [-0.4361, -0.4860, -0.5935, -0.6560],\n",
" [-1.0035, -0.8021, -0.3616, -0.5121]],\n",
"tensor([[[-0.3007, -1.0097, -0.6697, -0.2381],\n",
" [-1.0466, -0.9316, -0.9705, -0.3738],\n",
" [-0.2786, -0.2314, -0.3366, -0.4569]],\n",
"\n",
" [[-0.7634, -0.7910, -1.1640, -0.7275],\n",
" [-0.8157, -0.6291, -0.4723, -0.7292],\n",
" [-0.3052, -0.5596, -0.0048, -0.6224]]])\n"
" [[-0.5574, -0.8893, -0.3883, -0.2130],\n",
" [-0.5401, -0.8574, -0.4009, -0.1767],\n",
" [-0.6861, -0.5149, -0.7555, -0.4102]]], device='mps:0')\n"
]
}
],
Expand Down Expand Up @@ -614,9 +621,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/nacho/notebooks/tsai/nbs/076_models.MultiRocketPlus.ipynb saved at 2024-02-11 01:26:06\n",
"/Users/nacho/notebooks/tsai/nbs/076_models.MultiRocketPlus.ipynb saved at 2024-02-11 10:53:13\n",
"Correct notebook to script conversion! 😃\n",
"Sunday 11/02/24 01:26:09 CET\n"
"Sunday 11/02/24 10:53:16 CET\n"
]
},
{
Expand Down
Binary file modified nbs/models/test.pth
Binary file not shown.
22 changes: 10 additions & 12 deletions tsai/models/MultiRocketPlus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,22 @@ def forward(self, x): return x.view(x.size(0), -1)

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 5
def _LPVV(o, dim=2):
"Longest stretch of positive values (-1, 1)"
"Longest stretch of positive values along a dimension(-1, 1)"

seq_len = o.shape[dim]

# Convert tensor to binary format (1 for positive values, 0 for non-positive values)
binary_tensor = (o > 0).float()

# Find the changes in the binary tensor
diff = torch.cat([torch.ones_like(binary_tensor.narrow(dim, 0, 1)),
binary_tensor.narrow(dim, 1, binary_tensor.shape[dim]-1) - binary_tensor.narrow(dim, 0, binary_tensor.shape[dim]-1)], dim=dim)
binary_tensor.narrow(dim, 1, seq_len-1) - binary_tensor.narrow(dim, 0, seq_len-1)], dim=dim)

# Create groups of positive values
groups = (diff > 0).cumsum(dim)

# Count the number of values in each group
counts = torch.zeros_like(binary_tensor).scatter_add_(dim, groups * binary_tensor.long(), binary_tensor)
# Ensure groups are within valid index bounds
groups = groups * binary_tensor.long()
valid_groups = groups.where(groups < binary_tensor.size(dim), torch.tensor(0, device=groups.device))

counts = torch.zeros_like(binary_tensor).scatter_add_(dim, valid_groups, binary_tensor)

# The longest stretch of positive values is the maximum count
longest_stretch = counts.max(dim)[0]

return torch.nan_to_num(2 * (longest_stretch / seq_len) - 1)
Expand Down Expand Up @@ -67,7 +65,7 @@ def _PPV(o_pos, dim=2):
"Proportion of Positive Values (-1, 1)"
return (o_pos).float().mean(dim) * 2 - 1

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 10
# %% ../../nbs/076_models.MultiRocketPlus.ipynb 11
class MultiRocketFeaturesPlus(nn.Module):
fitting = False

Expand Down Expand Up @@ -239,7 +237,7 @@ def get_indices(self, kernel_size, max_num_kernels):
len(indices), max_num_kernels, False))]
return indices, pos_values

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 11
# %% ../../nbs/076_models.MultiRocketPlus.ipynb 12
class MultiRocketBackbonePlus(nn.Module):
def __init__(self, c_in, seq_len, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84, use_diff=True):
super(MultiRocketBackbonePlus, self).__init__()
Expand All @@ -266,7 +264,7 @@ def forward(self, x):
output = self.branch_x(x)
return output

# %% ../../nbs/076_models.MultiRocketPlus.ipynb 12
# %% ../../nbs/076_models.MultiRocketPlus.ipynb 13
class MultiRocketPlus(nn.Sequential):

def __init__(self, c_in, c_out, seq_len, d=None, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84,
Expand Down

0 comments on commit fa66212

Please sign in to comment.