diff --git a/deeptables/datasets/dsutils.py b/deeptables/datasets/dsutils.py index 45c870b..c4b2801 100644 --- a/deeptables/datasets/dsutils.py +++ b/deeptables/datasets/dsutils.py @@ -41,3 +41,10 @@ def load_bank(): data = pd.read_csv(f'{basedir}/bank-uci.csv') logger.info(f'data shape:{data.shape}') return data + + +def load_movielens(): + logger.info(f'Base dir:{os.getcwd()}') + data = pd.read_csv(f'{basedir}/movielens_sample.txt') + logger.info(f'data shape:{data.shape}') + return data diff --git a/deeptables/datasets/movielens_sample.txt b/deeptables/datasets/movielens_sample.txt new file mode 100644 index 0000000..9ffa148 --- /dev/null +++ b/deeptables/datasets/movielens_sample.txt @@ -0,0 +1,201 @@ +user_id,movie_id,rating,timestamp,title,genres,gender,age,occupation,zip +3299,235,4,968035345,Ed Wood (1994),Comedy|Drama,F,25,4,19119 +3630,3256,3,966536874,Patriot Games (1992),Action|Thriller,M,18,4,77005 +517,105,4,976203603,"Bridges of Madison County, The (1995)",Drama|Romance,F,25,14,55408 +785,2115,3,975430389,Indiana Jones and the Temple of Doom (1984),Action|Adventure,M,18,19,29307 +5848,909,5,957782527,"Apartment, The (1960)",Comedy|Drama,M,50,20,20009 +2996,2799,1,972769867,Problem Child 2 (1991),Comedy,M,18,0,63011 +3087,837,5,969738869,Matilda (1996),Children's|Comedy,F,1,1,90802 +872,3092,5,975273310,Chushingura (1962),Drama,M,50,1,20815 +4094,529,5,966223349,Searching for Bobby Fischer (1993),Drama,M,25,17,49017 +1868,3508,3,974694703,"Outlaw Josey Wales, The (1976)",Western,M,50,11,92346 +2913,1387,5,971769808,Jaws (1975),Action|Horror,F,35,20,98119 +380,3481,5,976316283,High Fidelity (2000),Comedy,M,25,2,92024 +2073,1784,5,974759084,As Good As It Gets (1997),Comedy|Drama,F,18,4,13148 +80,2059,3,977788576,"Parent Trap, The (1998)",Children's|Drama,M,56,1,49327 +3679,2557,1,976298130,I Stand Alone (Seul contre tous) (1998),Drama,M,25,4,68108 +2077,788,3,980013556,"Nutty Professor, The (1996)",Comedy|Fantasy|Romance|Sci-Fi,M,18,0,55112 +6036,2085,4,956716684,101 Dalmatians (1961),Animation|Children's,F,25,15,32603 +3675,532,3,966363610,Serial Mom (1994),Comedy|Crime|Horror,M,35,7,06680 +4566,3683,4,964489599,Blood Simple (1984),Drama|Film-Noir,M,35,17,19473 +2996,3763,3,972413564,F/X (1986),Action|Crime|Thriller,M,18,0,63011 +5831,2458,1,957898337,Armed and Dangerous (1986),Comedy|Crime,M,25,1,92120 +1869,1244,2,974695654,Manhattan (1979),Comedy|Drama|Romance,M,45,14,95148 +5389,2657,3,960328279,"Rocky Horror Picture Show, The (1975)",Comedy|Horror|Musical|Sci-Fi,M,45,7,01905 +1391,1535,3,974851275,Love! Valour! Compassion! (1997),Drama|Romance,M,35,15,20723 +3123,2407,3,969324381,Cocoon (1985),Comedy|Sci-Fi,M,25,2,90401 +4694,159,3,963602574,Clockers (1995),Drama,M,56,7,40505 +1680,1988,3,974709821,Hello Mary Lou: Prom Night II (1987),Horror,M,25,20,95380 +2002,1945,4,974677761,On the Waterfront (1954),Crime|Drama,F,56,13,02136-1522 +3430,2690,4,979949863,"Ideal Husband, An (1999)",Comedy,F,45,1,15208 +425,471,4,976284972,"Hudsucker Proxy, The (1994)",Comedy|Romance,M,25,12,55303 +1841,2289,2,974699637,"Player, The (1992)",Comedy|Drama,M,18,0,95037 +4964,2348,4,962619587,Sid and Nancy (1986),Drama,M,35,0,94110 +4520,2160,4,964883648,Rosemary's Baby (1968),Horror|Thriller,M,25,4,45810 +1265,2396,4,1011716691,Shakespeare in Love (1998),Comedy|Romance,F,18,20,49321 +2496,1278,5,974435324,Young Frankenstein (1974),Comedy|Horror,M,50,1,37932 +5511,2174,4,959787754,Beetlejuice (1988),Comedy|Fantasy,M,45,1,92407 +621,833,1,975799925,High School High (1996),Comedy,M,18,4,93560 +3045,2762,5,970189524,"Sixth Sense, The (1999)",Thriller,M,45,1,90631 +2050,2546,4,975522689,"Deep End of the Ocean, The (1999)",Drama,F,35,3,99504 +613,32,4,975812238,Twelve Monkeys (1995),Drama|Sci-Fi,M,35,20,10562 +366,1077,5,978471241,Sleeper (1973),Comedy|Sci-Fi,M,50,15,55126 +5108,367,4,962338215,"Mask, The (1994)",Comedy|Crime|Fantasy,F,25,9,93940 +4502,1960,4,965094644,"Last Emperor, The (1987)",Drama|War,M,50,0,01379 +5512,1801,5,959713840,"Man in the Iron Mask, The (1998)",Action|Drama|Romance,F,25,17,01701 +1861,2642,2,974699627,Superman III (1983),Action|Adventure|Sci-Fi,M,50,16,92129 +1667,1240,4,975016698,"Terminator, The (1984)",Action|Sci-Fi|Thriller,M,50,16,98516 +753,434,3,975460449,Cliffhanger (1993),Action|Adventure|Crime,M,1,10,42754 +1836,2736,5,974826228,Brighton Beach Memoirs (1986),Comedy,M,25,0,10016 +5626,474,5,959052158,In the Line of Fire (1993),Action|Thriller,M,56,16,32043 +1601,1396,4,978576948,Sneakers (1992),Crime|Drama|Sci-Fi,M,25,12,83001 +4725,1100,4,963369546,Days of Thunder (1990),Action|Romance,M,35,5,96707-1321 +2837,2396,5,972571456,Shakespeare in Love (1998),Comedy|Romance,M,18,0,49506 +1776,3882,4,1001558470,Bring It On (2000),Comedy,M,25,0,45801 +2820,457,2,972662398,"Fugitive, The (1993)",Action|Thriller,F,35,0,02138 +1834,2288,3,1038179198,"Thing, The (1982)",Action|Horror|Sci-Fi|Thriller,M,35,5,10990 +284,2716,4,976570902,Ghostbusters (1984),Comedy|Horror,M,25,12,91910 +2744,588,1,973215985,Aladdin (1992),Animation|Children's|Comedy|Musical,M,18,17,53818 +881,4,2,975264028,Waiting to Exhale (1995),Comedy|Drama,M,18,14,76401 +2211,916,3,974607067,Roman Holiday (1953),Comedy|Romance,M,45,6,01950 +2271,2671,4,1007158806,Notting Hill (1999),Comedy|Romance,M,50,14,13210 +1010,2953,1,975222613,Home Alone 2: Lost in New York (1992),Children's|Comedy,M,25,0,10310 +1589,2594,4,974735454,Open Your Eyes (Abre los ojos) (1997),Drama|Romance|Sci-Fi,M,25,0,95136 +1724,597,5,976441106,Pretty Woman (1990),Comedy|Romance,M,18,4,00961 +2590,2097,3,973840056,Something Wicked This Way Comes (1983),Children's|Horror,M,18,4,94044 +1717,1352,3,1009256707,Albino Alligator (1996),Crime|Thriller,F,50,6,30307 +1391,3160,2,974850796,Magnolia (1999),Drama,M,35,15,20723 +1941,1263,3,974954220,"Deer Hunter, The (1978)",Drama|War,M,35,17,94550 +3526,2867,4,966906064,Fright Night (1985),Comedy|Horror,M,35,2,62263-3004 +5767,198,3,958192148,Strange Days (1995),Action|Crime|Sci-Fi,M,25,2,75287 +5355,590,4,960596927,Dances with Wolves (1990),Adventure|Drama|Western,M,56,0,78232 +5788,156,4,958108785,Blue in the Face (1995),Comedy,M,25,0,92646 +1078,1307,4,974938851,When Harry Met Sally... (1989),Comedy|Romance,F,45,9,95661 +3808,61,2,965973222,Eye for an Eye (1996),Drama|Thriller,M,25,7,60010 +974,3897,4,975106398,Almost Famous (2000),Comedy|Drama,M,35,19,94930 +5153,1290,4,961972292,Some Kind of Wonderful (1987),Drama|Romance,M,25,7,60046 +5732,2115,3,958434069,Indiana Jones and the Temple of Doom (1984),Action|Adventure,F,25,11,02111 +4627,2478,3,964110136,Three Amigos! (1986),Comedy|Western,M,56,1,45224 +1884,1831,2,975648062,Lost in Space (1998),Action|Sci-Fi|Thriller,M,45,20,93108 +4284,517,4,965277546,Rising Sun (1993),Action|Drama|Mystery,M,50,7,40601 +1383,468,2,975979732,"Englishman Who Went Up a Hill, But Came Down a Mountain, The (1995)",Comedy|Romance,F,25,7,19806 +2230,2873,3,974599097,Lulu on the Bridge (1998),Drama|Mystery|Romance,F,45,1,60302 +2533,2266,4,974055724,"Butcher's Wife, The (1991)",Comedy|Romance,F,25,3,49423 +6040,3224,5,956716750,Woman in the Dunes (Suna no onna) (1964),Drama,M,25,6,11106 +4384,2918,5,965171739,Ferris Bueller's Day Off (1986),Comedy,M,25,0,43623 +5156,3688,3,961946487,Porky's (1981),Comedy,M,18,14,10024 +615,296,3,975805801,Pulp Fiction (1994),Crime|Drama,M,50,17,32951 +2753,3045,3,973198964,Peter's Friends (1992),Comedy|Drama,F,50,20,27516 +2438,1125,5,974259943,"Return of the Pink Panther, The (1974)",Comedy,M,35,1,22903 +5746,1242,4,958354460,Glory (1989),Action|Drama|War,M,18,15,94061 +5157,3462,5,961944604,Modern Times (1936),Comedy,M,35,1,74012 +3402,1252,5,967433929,Chinatown (1974),Film-Noir|Mystery|Thriller,M,35,20,30306 +76,593,5,977847255,"Silence of the Lambs, The (1991)",Drama|Thriller,M,35,7,55413 +2067,1019,3,974658834,"20,000 Leagues Under the Sea (1954)",Adventure|Children's|Fantasy|Sci-Fi,M,50,16,06430 +2181,2020,3,979353437,Dangerous Liaisons (1988),Drama|Romance,M,25,0,45245 +3947,593,5,965691680,"Silence of the Lambs, The (1991)",Drama|Thriller,M,25,0,90019 +546,218,4,976069421,Boys on the Side (1995),Comedy|Drama,F,25,0,37211 +1246,3030,5,1032056405,Yojimbo (1961),Comedy|Drama|Western,M,18,4,98225 +4214,3186,5,965319143,"Girl, Interrupted (1999)",Drama,F,25,0,20121 +2841,680,3,982805796,Alphaville (1965),Sci-Fi,M,50,12,98056 +4205,3175,4,965321085,Galaxy Quest (1999),Adventure|Comedy|Sci-Fi,F,25,15,87801 +1120,1097,4,974911354,E.T. the Extra-Terrestrial (1982),Children's|Drama|Fantasy|Sci-Fi,M,18,4,95616 +5371,3194,3,960481000,"Way We Were, The (1973)",Drama,M,25,11,55408 +2695,1278,5,973310827,Young Frankenstein (1974),Comedy|Horror,M,35,11,46033 +3312,520,2,976673070,Robin Hood: Men in Tights (1993),Comedy,F,18,4,90039 +5039,1792,1,962513044,U.S. Marshalls (1998),Action|Thriller,F,35,4,97068 +4655,2146,3,963903103,St. Elmo's Fire (1985),Drama|Romance,F,25,1,92037 +3558,1580,5,966802528,Men in Black (1997),Action|Adventure|Comedy|Sci-Fi,M,18,17,66044 +506,3354,1,976208080,Mission to Mars (2000),Sci-Fi,M,25,16,55103-1006 +3568,1230,3,966745594,Annie Hall (1977),Comedy|Romance,M,25,0,98503 +2943,1197,5,971319983,"Princess Bride, The (1987)",Action|Adventure|Comedy|Romance,M,35,12,95864 +716,737,3,982881364,Barb Wire (1996),Action|Sci-Fi,M,18,4,98188 +5964,454,3,956999469,"Firm, The (1993)",Drama|Thriller,M,18,5,97202 +4802,1208,4,996034747,Apocalypse Now (1979),Drama|War,M,56,1,40601 +1106,3624,4,974920622,Shanghai Noon (2000),Action,M,18,4,90241 +3410,2565,3,967419652,"King and I, The (1956)",Musical,M,35,1,20653 +1273,3095,5,974814536,"Grapes of Wrath, The (1940)",Drama,M,35,2,19123 +1706,1916,4,974709448,Buffalo 66 (1998),Action|Comedy|Drama,M,25,20,19134 +4889,590,5,962909224,Dances with Wolves (1990),Adventure|Drama|Western,M,18,4,63108 +4966,2100,3,962609782,Splash (1984),Comedy|Fantasy|Romance,M,50,14,55407 +4238,1884,4,965343416,Fear and Loathing in Las Vegas (1998),Comedy|Drama,M,35,16,44691 +5365,1042,3,960502974,That Thing You Do! (1996),Comedy,M,18,12,90250 +415,1302,3,977501743,Field of Dreams (1989),Drama,F,35,0,55406 +4658,1009,5,963966553,Escape to Witch Mountain (1975),Adventure|Children's|Fantasy,M,25,4,99163 +854,345,3,975357801,"Adventures of Priscilla, Queen of the Desert, The (1994)",Comedy|Drama,F,25,16,44092 +2857,436,4,972509362,Color of Night (1994),Drama|Thriller,M,25,0,10469 +1835,1330,4,974878241,April Fool's Day (1986),Comedy|Horror,M,25,19,11501 +1321,2240,3,974778494,My Bodyguard (1980),Drama,F,25,14,34639 +3274,3698,2,979767184,"Running Man, The (1987)",Action|Adventure|Sci-Fi,M,25,20,02062 +5893,2144,3,957470619,Sixteen Candles (1984),Comedy,M,25,7,02139 +3436,2724,3,967328026,Runaway Bride (1999),Comedy|Romance,M,35,0,98503 +3315,2918,5,967942960,Ferris Bueller's Day Off (1986),Comedy,M,25,12,78731 +5056,2700,5,962488280,"South Park: Bigger, Longer and Uncut (1999)",Animation|Comedy,M,45,1,16673 +5256,208,2,961271616,Waterworld (1995),Action|Adventure,M,25,16,30269 +4290,1193,4,965274348,One Flew Over the Cuckoo's Nest (1975),Drama,M,25,17,98661 +1010,1379,2,975220259,Young Guns II (1990),Action|Comedy|Western,M,25,0,10310 +829,904,4,975368038,Rear Window (1954),Mystery|Thriller,M,1,19,53711 +5953,480,4,957143581,Jurassic Park (1993),Action|Adventure|Sci-Fi,M,1,10,21030 +4732,3016,4,963332896,Creepshow (1982),Horror,M,25,14,24450 +4815,3181,5,972240802,Titus (1999),Drama,F,50,18,04849 +1164,1894,2,1004486985,Six Days Seven Nights (1998),Adventure|Comedy|Romance,F,25,19,90020 +4373,3167,5,965180829,Carnal Knowledge (1971),Drama,M,50,12,32920 +5293,1374,4,961055887,Star Trek: The Wrath of Khan (1982),Action|Adventure|Sci-Fi,M,25,12,95030 +1579,3101,4,981272057,Fatal Attraction (1987),Thriller,M,25,0,60201 +2600,3147,5,973804787,"Green Mile, The (1999)",Drama|Thriller,M,25,14,19312 +1283,480,4,974793389,Jurassic Park (1993),Action|Adventure|Sci-Fi,F,18,1,94607 +3242,3062,5,968341175,"Longest Day, The (1962)",Action|Drama|War,M,50,13,94089 +3618,3374,3,967116272,Daughters of the Dust (1992),Drama,M,56,17,22657 +3762,1337,4,966434517,"Body Snatcher, The (1945)",Horror,M,50,6,11746 +1015,1184,3,975018699,Mediterraneo (1991),Comedy|War,M,35,3,11220 +4645,2344,5,963976808,Runaway Train (1985),Action|Adventure|Drama|Thriller,F,50,6,48094 +3184,1397,4,968709039,Bastard Out of Carolina (1996),Drama,F,25,18,21214 +1285,1794,4,974833328,Love and Death on Long Island (1997),Comedy|Drama,M,35,4,98125 +5521,3354,2,959833154,Mission to Mars (2000),Sci-Fi,F,25,6,02118 +1472,2278,3,974767792,Ronin (1998),Action|Crime|Thriller,M,25,7,90248 +5630,21,4,980085414,Get Shorty (1995),Action|Comedy|Drama,M,35,17,06854 +3710,3033,5,966272980,Spaceballs (1987),Comedy|Sci-Fi,M,1,10,02818 +192,761,1,977028390,"Phantom, The (1996)",Adventure,M,18,1,10977 +1285,1198,5,974880310,Raiders of the Lost Ark (1981),Action|Adventure,M,35,4,98125 +2174,1046,4,974613044,Beautiful Thing (1996),Drama|Romance,M,50,12,87505 +635,1270,4,975768106,Back to the Future (1985),Comedy|Sci-Fi,M,56,17,33785 +910,412,5,975207742,"Age of Innocence, The (1993)",Drama,F,50,0,98226 +1752,2021,4,975729332,Dune (1984),Fantasy|Sci-Fi,M,25,3,96813 +1408,198,4,974762924,Strange Days (1995),Action|Crime|Sci-Fi,M,25,0,90046 +4738,1242,4,963279051,Glory (1989),Action|Drama|War,M,56,1,23608 +1503,1971,2,974748897,"Nightmare on Elm Street 4: The Dream Master, A (1988)",Horror,M,25,12,92688 +3053,1296,3,970601837,"Room with a View, A (1986)",Drama|Romance,F,25,3,55102 +3471,3614,2,973297828,Honeymoon in Vegas (1992),Comedy|Romance,M,18,4,80302 +678,1972,3,988638700,"Nightmare on Elm Street 5: The Dream Child, A (1989)",Horror,M,25,0,34952 +3483,2561,3,986327282,True Crime (1999),Crime|Thriller,F,45,7,30260 +3910,3108,5,965756244,"Fisher King, The (1991)",Comedy|Drama|Romance,M,25,20,91505 +182,1089,1,977085647,Reservoir Dogs (1992),Crime|Thriller,M,18,4,03052 +1755,1653,3,1036917836,Gattaca (1997),Drama|Sci-Fi|Thriller,F,18,4,77005 +3589,70,2,966658567,From Dusk Till Dawn (1996),Action|Comedy|Crime|Horror|Thriller,F,45,0,80010 +471,3481,4,976222483,High Fidelity (2000),Comedy,M,35,7,08904 +1141,813,2,974878678,Larger Than Life (1996),Comedy,F,25,3,84770 +5227,1196,2,961476022,Star Wars: Episode V - The Empire Strikes Back (1980),Action|Adventure|Drama|Sci-Fi|War,M,18,10,64050 +1303,2344,2,974837844,Runaway Train (1985),Action|Adventure|Drama|Thriller,M,25,19,94111 +5080,3102,5,962412804,Jagged Edge (1985),Thriller,F,50,12,95472 +2023,1012,4,1006290836,Old Yeller (1957),Children's|Drama,M,18,4,56001 +3759,2151,5,966094413,"Gods Must Be Crazy II, The (1989)",Comedy,M,35,6,54751 +1685,2664,2,974709721,Invasion of the Body Snatchers (1956),Horror|Sci-Fi,M,35,12,95833 +4715,1221,4,963508830,"Godfather: Part II, The (1974)",Action|Crime|Drama,M,25,2,97205 +1591,350,5,974742941,"Client, The (1994)",Drama|Mystery|Thriller,M,50,7,26501 +4227,3635,3,965411938,"Spy Who Loved Me, The (1977)",Action,M,25,19,11414-2520 +1908,36,5,974697744,Dead Man Walking (1995),Drama,M,56,13,95129 +5365,1892,4,960503255,"Perfect Murder, A (1998)",Mystery|Thriller,M,18,12,90250 +1579,2420,4,981272235,"Karate Kid, The (1984)",Drama,M,25,0,60201 +1866,3948,5,974753321,Meet the Parents (2000),Comedy,M,25,7,94043 +4238,3543,4,965415533,Diner (1982),Comedy|Drama,M,35,16,44691 +3590,2000,5,966657892,Lethal Weapon (1987),Action|Comedy|Crime|Drama,F,18,15,02115 +3401,3256,5,980115327,Patriot Games (1992),Action|Thriller,M,35,7,76109 +3705,540,2,966287116,Sliver (1993),Thriller,M,45,7,30076 +4973,1246,3,962607149,Dead Poets Society (1989),Drama,F,56,2,949702 +4947,380,4,962651180,True Lies (1994),Action|Adventure|Comedy|Romance,M,35,17,90035 +2346,1416,4,974413811,Evita (1996),Drama|Musical,F,1,10,48105 +1427,3596,3,974840560,Screwed (2000),Comedy,M,25,12,21401 +3868,1626,3,965855033,Fire Down Below (1997),Action|Drama|Thriller,M,18,12,73112 +249,2369,3,976730191,Desperately Seeking Susan (1985),Comedy|Romance,F,18,14,48126 +5720,349,4,958503395,Clear and Present Danger (1994),Action|Adventure|Thriller,M,25,0,60610 +877,1485,3,975270899,Liar Liar (1997),Comedy,M,25,0,90631 diff --git a/deeptables/models/config.py b/deeptables/models/config.py index 37ebb35..f152880 100644 --- a/deeptables/models/config.py +++ b/deeptables/models/config.py @@ -49,6 +49,7 @@ class ModelConfig(collections.namedtuple('ModelConfig', 'earlystopping_patience', 'gpu_usage_strategy', 'distribute_strategy', + 'var_len_categorical_columns', ])): def __hash__(self): return self.name.__hash__() @@ -126,7 +127,22 @@ def __new__(cls, earlystopping_patience=1, gpu_usage_strategy=consts.GPU_USAGE_STRATEGY_GROWTH, distribute_strategy=None, + var_len_categorical_columns=None, # a tuple3, format is (column_name, separator, pool_strategy), pool_strategy is one of max,avg; e.g. [('genres', '|', 'avg' )] ): + + if var_len_categorical_columns is not None and len(var_len_categorical_columns) > 0: + # check items + for v in var_len_categorical_columns: + _name = v[0] + if not isinstance(v, (tuple, list)) or len(v) != 3: + raise ValueError("Var len column config should be a tuple 3.") + if exclude_columns is not None: + if _name in exclude_columns: + raise ValueError(f"Var len column {_name} can not put in 'exclude_columns' ") + if categorical_columns is not None and isinstance(categorical_columns, list): + if _name in categorical_columns: + raise ValueError(f"Var len column {_name} can not put in 'categorical_columns' ") + nets = deepnets.get_nets(nets) if home_dir is None and os.environ.get(consts.ENV_DEEPTABLES_HOME) is not None: @@ -175,6 +191,7 @@ def __new__(cls, earlystopping_patience, gpu_usage_strategy, distribute_strategy, + var_len_categorical_columns, ) @property def first_metric_name(self): diff --git a/deeptables/models/deepmodel.py b/deeptables/models/deepmodel.py index f414f19..1a7c447 100644 --- a/deeptables/models/deepmodel.py +++ b/deeptables/models/deepmodel.py @@ -1,5 +1,5 @@ # -*- coding:utf-8 -*- - +from typing import List from collections import OrderedDict import collections import numpy as np @@ -9,8 +9,11 @@ from tensorflow.keras.layers import Dense, Concatenate, Flatten, Input, Add, BatchNormalization, Dropout from tensorflow.keras.models import Model, load_model, save_model from tensorflow.keras.utils import to_categorical + +from deeptables.models.metainfo import CategoricalColumn from . import deepnets -from .layers import MultiColumnEmbedding, dt_custom_objects +from .layers import MultiColumnEmbedding, dt_custom_objects, VarLenColumnEmbedding +from .metainfo import VarLenCategoricalColumn from ..utils import dt_logging, consts, gpu logger = dt_logging.get_logger() @@ -25,6 +28,7 @@ def __init__(self, config, categorical_columns, continuous_columns, + var_categorical_len_columns=None, # Compatible persisted model model_file=None): # set gpu usage strategy before build model @@ -33,6 +37,7 @@ def __init__(self, self.model_desc = ModelDesc() self.categorical_columns = categorical_columns self.continuous_columns = continuous_columns + self.var_len_categorical_columns = var_categorical_len_columns self.task = task self.num_classes = num_classes self.config = config @@ -73,6 +78,7 @@ def fit(self, X=None, y=None, batch_size=128, epochs=1, verbose=1, callbacks=Non nets=self.config.nets, categorical_columns=self.categorical_columns, continuous_columns=self.continuous_columns, + var_len_categorical_columns=self.var_len_categorical_columns, config=self.config) else: self.model = self.__build_model(task=self.task, @@ -80,6 +86,7 @@ def fit(self, X=None, y=None, batch_size=128, epochs=1, verbose=1, callbacks=Non nets=self.config.nets, categorical_columns=self.categorical_columns, continuous_columns=self.continuous_columns, + var_len_categorical_columns=self.var_len_categorical_columns, config=self.config) logger.info(f'training...') @@ -154,9 +161,22 @@ def release(self): K.clear_session() def __get_model_input(self, X): - input = [X[[c.name for c in self.categorical_columns]].values.astype(consts.DATATYPE_TENSOR_FLOAT)] + \ - [X[c.column_names].values.astype(consts.DATATYPE_TENSOR_FLOAT) for c in self.continuous_columns] - return input + train_data = {} + # add categorical data + if self.categorical_columns is not None and len(self.categorical_columns) > 0: + train_data['input_categorical_vars_all'] = X[[c.name for c in self.categorical_columns]].values.astype(consts.DATATYPE_TENSOR_FLOAT) + + # add continuous data + if self.continuous_columns is not None and len(self.continuous_columns) > 0: + for c in self.continuous_columns: + train_data[c.name] = X[c.column_names].values.astype(consts.DATATYPE_TENSOR_FLOAT) + + # add var len categorical data + if self.var_len_categorical_columns is not None and len(self.var_len_categorical_columns) > 0: + for col in self.var_len_categorical_columns: + train_data[col.name] = np.array(X[col.name].tolist()) + + return train_data def __buld_proxy_model(self, model, output_layers=[], concat_output=False): model.trainable = False @@ -172,11 +192,11 @@ def __buld_proxy_model(self, model, output_layers=[], concat_output=False): proxy.compile(optimizer=model.optimizer, loss=model.loss) return proxy - def __build_model(self, task, num_classes, nets, categorical_columns, continuous_columns, config): + def __build_model(self, task, num_classes, nets, categorical_columns, continuous_columns, var_len_categorical_columns, config): logger.info(f'Building model...') self.model_desc = ModelDesc() - categorical_inputs, continuous_inputs = self.__build_inputs(categorical_columns, continuous_columns) - embeddings = self.__build_embeddings(categorical_columns, categorical_inputs, config.embedding_dropout) + categorical_inputs, continuous_inputs, var_len_categorical_inputs = self.__build_inputs(categorical_columns, continuous_columns, var_len_categorical_columns) + embeddings = self.__build_embeddings(categorical_columns, categorical_inputs, var_len_categorical_columns, var_len_categorical_inputs, config.embedding_dropout) dense_layer = self.__build_denses(continuous_columns, continuous_inputs, config.dense_dropout) flatten_emb_layer = None @@ -190,6 +210,7 @@ def __build_model(self, task, num_classes, nets, categorical_columns, continuous self.model_desc.nets = nets self.model_desc.stacking = config.stacking_op concat_emb_dense = self.__concat_emb_dense(flatten_emb_layer, dense_layer) + # concat_emb_dense = flatten_emb_layer outs = {} for net in nets: logit = deepnets.get(net) @@ -220,7 +241,7 @@ def __build_model(self, task, num_classes, nets, categorical_columns, continuous x = out else: raise ValueError(f'Unexcepted logit output.{outs}') - all_inputs = list(categorical_inputs.values()) + list(continuous_inputs.values()) + all_inputs = list(categorical_inputs.values()) + list(var_len_categorical_inputs.values()) + list(continuous_inputs.values()) output = self.__output_layer(x, task, num_classes, use_bias=self.config.output_use_bias) model = Model(inputs=all_inputs, outputs=output) model = self.__compile_model(model, task, num_classes, config.optimizer, config.loss, config.metrics) @@ -261,33 +282,64 @@ def __concat_emb_dense(self, flatten_emb_layer, dense_layer): self.model_desc.set_concat_embed_dense(x.shape) return x - def __build_inputs(self, categorical_columns, continuous_columns): + def __build_inputs(self, categorical_columns: List[CategoricalColumn], continuous_columns, var_len_categorical_columns: List[VarLenCategoricalColumn]=None): categorical_inputs = OrderedDict() + var_len_categorical_inputs = OrderedDict() continuous_inputs = OrderedDict() - categorical_inputs['all_categorical_vars'] = Input(shape=(len(categorical_columns),), - name='input_categorical_vars_all') - self.model_desc.add_input('all_categorical_vars', len(categorical_columns)) - # for column in categorical_columns: + if categorical_columns is not None and len(categorical_columns) > 0: + categorical_inputs['all_categorical_vars'] = Input(shape=(len(categorical_columns),), + name='input_categorical_vars_all') + self.model_desc.add_input('all_categorical_vars', len(categorical_columns)) + + # make input for var len feature + if var_len_categorical_columns is not None and len(var_len_categorical_columns) > 0: + for col in var_len_categorical_columns: + var_len_categorical_inputs[col.name] = Input(shape=(col.max_elements_length, ), name=col.name) + self.model_desc.add_input(col.name, col.max_elements_length) + for column in continuous_columns: continuous_inputs[column.name] = Input(shape=(column.input_dim,), name=column.name, dtype=column.dtype) self.model_desc.add_input(column.name, column.input_dim) - return categorical_inputs, continuous_inputs + return categorical_inputs, continuous_inputs, var_len_categorical_inputs + + def __construct_var_len_embedding(self, column: VarLenCategoricalColumn, var_len_inputs, embedding_dropout): + input_layer = var_len_inputs[column.name] + var_len_embeddings = VarLenColumnEmbedding(pooling_strategy=column.pooling_strategy, + input_dim=column.vocabulary_size, + output_dim=column.embeddings_output_dim, + dropout_rate=embedding_dropout, + name=consts.LAYER_PREFIX_EMBEDDING + column.name, + embeddings_initializer=self.config.embeddings_initializer, + embeddings_regularizer=self.config.embeddings_regularizer, + activity_regularizer=self.config.embeddings_activity_regularizer + )(input_layer) + return var_len_embeddings + + def __build_embeddings(self, categorical_columns, categorical_inputs, var_len_categorical_columns: List[VarLenCategoricalColumn], var_len_inputs, embedding_dropout): + if 'all_categorical_vars' in categorical_inputs: + input_layer = categorical_inputs['all_categorical_vars'] + input_dims = [column.vocabulary_size for column in categorical_columns] + output_dims = [column.embeddings_output_dim for column in categorical_columns] + embeddings = MultiColumnEmbedding(input_dims, output_dims, embedding_dropout, + name=consts.LAYER_PREFIX_EMBEDDING + 'categorical_vars_all', + embeddings_initializer=self.config.embeddings_initializer, + embeddings_regularizer=self.config.embeddings_regularizer, + activity_regularizer=self.config.embeddings_activity_regularizer, + )(input_layer) + self.model_desc.set_embeddings(input_dims, output_dims, embedding_dropout) + else: + embeddings = [] - def __build_embeddings(self, categorical_columns, categorical_inputs, embedding_dropout): - input_layer = categorical_inputs['all_categorical_vars'] - input_dims = [column.vocabulary_size for column in categorical_columns] - output_dims = [column.embeddings_output_dim for column in categorical_columns] - embeddings = MultiColumnEmbedding(input_dims, output_dims, embedding_dropout, - name=consts.LAYER_PREFIX_EMBEDDING + 'categorical_vars_all', - embeddings_initializer=self.config.embeddings_initializer, - embeddings_regularizer=self.config.embeddings_regularizer, - activity_regularizer=self.config.embeddings_activity_regularizer, - )(input_layer) + # do embedding for var len feature + if var_len_categorical_columns is not None and len(var_len_categorical_columns) > 0: + for c in var_len_categorical_columns: + # todo add var len embedding description + var_len_embedding = self.__construct_var_len_embedding(c, var_len_inputs, embedding_dropout) + embeddings.append(var_len_embedding) - self.model_desc.set_embeddings(input_dims, output_dims, embedding_dropout) return embeddings def __build_denses(self, continuous_columns, continuous_inputs, dense_dropout, use_batchnormalization=False): diff --git a/deeptables/models/deeptable.py b/deeptables/models/deeptable.py index 536ab24..7bb8403 100644 --- a/deeptables/models/deeptable.py +++ b/deeptables/models/deeptable.py @@ -335,6 +335,9 @@ def fit(self, X=None, y=None, batch_size=128, epochs=1, verbose=1, callbacks=Non max_queue_size=10, workers=1, use_multiprocessing=False): logger.info(f'X.Shape={np.shape(X)}, y.Shape={np.shape(y)}, batch_size={batch_size}, config={self.config}') logger.info(f'metrics:{self.config.metrics}') + X_shape = np.shape(X) + if X_shape[1] < 1: + raise ValueError("Input train data should has 1 feature at least.") self.__modelset.clear() X, y = self.preprocessor.fit_transform(X, y) @@ -348,7 +351,8 @@ def fit(self, X=None, y=None, batch_size=128, epochs=1, verbose=1, callbacks=Non callbacks = self.__inject_callbacks(callbacks) model = DeepModel(self.task, self.num_classes, self.config, self.preprocessor.categorical_columns, - self.preprocessor.continuous_columns) + self.preprocessor.continuous_columns, + self.preprocessor.var_len_categorical_columns) history = model.fit(X, y, batch_size=batch_size, epochs=epochs, verbose=verbose, shuffle=shuffle, validation_split=validation_split, validation_data=validation_data, validation_steps=validation_steps, validation_freq=validation_freq, diff --git a/deeptables/models/layers.py b/deeptables/models/layers.py index eeff601..3b612b4 100644 --- a/deeptables/models/layers.py +++ b/deeptables/models/layers.py @@ -913,6 +913,57 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) +class VarLenColumnEmbedding(Embedding): + def __init__(self, pooling_strategy='max', dropout_rate=0., **kwargs): + if pooling_strategy not in ['mean', 'max']: + raise ValueError("Param strategy should is one of mean, max") + self.pooling_strategy = pooling_strategy + self.dropout_rate = dropout_rate # 支持dropout + super(VarLenColumnEmbedding, self).__init__(**kwargs) + + def build(self, input_shape): + super(VarLenColumnEmbedding, self).build(input_shape) # Be sure to call this somewhere! + height = input_shape[1] + if self.pooling_strategy == "mean": + self._pooling_layer = tf.keras.layers.AveragePooling2D(pool_size=(height, 1)) + else: + self._pooling_layer = tf.keras.layers.MaxPooling2D(pool_size=(height, 1)) + + if self.dropout_rate > 0: + self._dropout = SpatialDropout1D(self.dropout_rate) + else: + self._dropout = None + + self.built = True + + def call(self, inputs): + # 1. do embedding + embedding_output = super(VarLenColumnEmbedding, self).call(inputs) + + # 2. add dropout + if self._dropout is not None: + dropout_output = self._dropout(embedding_output) + else: + dropout_output = embedding_output + + # 3. expand dim for polling + inputs_4d = tf.expand_dims(dropout_output, 3) # add channels dim + + # 4. polling + tensor_pooling = self._pooling_layer(inputs_4d) + + # 5. format output + return tf.squeeze(tensor_pooling, 3) + + def compute_mask(self, inputs, mask): + return None + + def get_config(self, ): + config = {'pooling_strategy': self.pooling_strategy} + base_config = super(VarLenColumnEmbedding, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + class BinaryFocalLoss(losses.Loss): """ Binary form of focal loss. diff --git a/deeptables/models/metainfo.py b/deeptables/models/metainfo.py index f0668c2..b9f7e02 100644 --- a/deeptables/models/metainfo.py +++ b/deeptables/models/metainfo.py @@ -49,6 +49,29 @@ def __new__(cls, name, vocabulary_size, embeddings_output_dim=10, dtype='int32', input_name) +class VarLenCategoricalColumn(collections.namedtuple('VarLenCategoricalColumn', + ['name', + 'vocabulary_size', + 'embeddings_output_dim', + 'dtype', + 'input_name', + 'sep', + 'pooling_strategy', + ])): + + def __hash__(self): + return self.name.__hash__() + + def __new__(cls, name, vocabulary_size, embeddings_output_dim=10, dtype='int32', input_name=None, sep="|", pooling_strategy='max'): + if input_name is None: + input_name = consts.INPUT_PREFIX_CAT + name + if embeddings_output_dim == 0: + embeddings_output_dim = int(round(vocabulary_size ** 0.25)) + # max_elements_length need a variable not const + return super(VarLenCategoricalColumn, cls).__new__(cls, name, vocabulary_size, embeddings_output_dim, dtype, + input_name, sep, pooling_strategy) + + class ContinuousColumn(collections.namedtuple('ContinuousColumn', ['name', 'column_names', diff --git a/deeptables/models/preprocessor.py b/deeptables/models/preprocessor.py index 368dfea..64c8fc6 100644 --- a/deeptables/models/preprocessor.py +++ b/deeptables/models/preprocessor.py @@ -15,8 +15,8 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import LabelEncoder from tensorflow.keras.utils import to_categorical -from deeptables.preprocessing.transformer import PassThroughEstimator -from .metainfo import CategoricalColumn, ContinuousColumn +from deeptables.preprocessing.transformer import PassThroughEstimator, VarLenFeatureEncoder, MultiVarLenFeatureEncoder +from .metainfo import CategoricalColumn, ContinuousColumn, VarLenCategoricalColumn from ..preprocessing import MultiLabelEncoder, MultiKBinsDiscretizer, DataFrameWrapper, LgbmLeavesEncoder, \ CategorizeEncoder from ..utils import dt_logging, consts @@ -130,6 +130,7 @@ def __init__(self, config: ModelConfig, cache_home=None, use_cache=False): def reset(self): self.metainfo = None self.categorical_columns = None + self.var_len_categorical_columns = None self.continuous_columns = None self.y_lable_encoder = None self.X_transformers = collections.OrderedDict() @@ -189,6 +190,9 @@ def fit_transform(self, X, y, copy_data=True): X = self._discretization(X) if self.config.apply_gbm_features and y is not None: X = self._apply_gbm_features(X, y) + var_len_categorical_columns = self.config.var_len_categorical_columns + if var_len_categorical_columns is not None and len(var_len_categorical_columns) > 0: + X = self._var_len_encoder(X, var_len_categorical_columns) self.X_transformers['last'] = PassThroughEstimator() @@ -289,6 +293,21 @@ def __prepare_features(self, X): if self.config.cat_exponent >= 1: raise ValueError(f'"cat_expoent" must be less than 1, not {self.config.cat_exponent} .') + var_len_categorical_columns = self.config.var_len_categorical_columns + var_len_column_names = [] + if var_len_categorical_columns is not None and len(var_len_categorical_columns) > 0: + # check items + for v in var_len_categorical_columns: + if not isinstance(v, (tuple, list)) or len(v) != 3: + raise ValueError("Var len column config should be a tuple 3.") + else: + var_len_column_names.append(v[0]) + var_len_col_sep_dict = {v[0]: v[1] for v in var_len_categorical_columns} + var_len_col_pooling_strategy_dict = {v[0]: v[2] for v in var_len_categorical_columns} + else: + var_len_col_sep_dict = {} + var_len_col_pooling_strategy_dict = {} + unique_upper_limit = round(X.shape[0] ** self.config.cat_exponent) for c in X.columns: nunique = X[c].nunique() @@ -300,6 +319,12 @@ def __prepare_features(self, X): if c in self.config.exclude_columns: excluded_vars.append((c, dtype, nunique)) continue + + # handle var len feature + if c in var_len_column_names: + self.__append_var_len_categorical_col(c, nunique, var_len_col_sep_dict[c], var_len_col_pooling_strategy_dict[c]) + continue + if self.config.categorical_columns is not None and isinstance(self.config.categorical_columns, list): if c in self.config.categorical_columns: cat_vars.append((c, dtype, nunique)) @@ -340,12 +365,19 @@ def _imputation(self, X): logger.info('Data imputation...') continuous_vars = self.get_continuous_columns() categorical_vars = self.get_categorical_columns() - ct = ColumnTransformer([ + var_len_categorical_vars = self.get_var_len_categorical_columns() + + transformers = [ ('categorical', SimpleImputer(missing_values=np.nan, strategy='constant'), categorical_vars), ('continuous', SimpleImputer(missing_values=np.nan, strategy='mean'), continuous_vars), - ]) - dfwrapper = DataFrameWrapper(ct, categorical_vars + continuous_vars) + ] + + if len(var_len_categorical_vars) > 0: + transformers.append(('var_len_categorical', SimpleImputer(missing_values=np.nan, strategy='constant'), var_len_categorical_vars),) + + ct = ColumnTransformer(transformers) + dfwrapper = DataFrameWrapper(ct, categorical_vars + continuous_vars + var_len_categorical_vars) X = dfwrapper.fit_transform(X) self.X_transformers['imputation'] = dfwrapper print(f'Imputation taken {time.time() - start}s') @@ -372,6 +404,21 @@ def _discretization(self, X): print(f'Discretization taken {time.time() - start}s') return X + def _var_len_encoder(self, X, var_len_categorical_columns): + start = time.time() + logger.info('Encoder var length feature...') + transformer = MultiVarLenFeatureEncoder(var_len_categorical_columns) + X = transformer.fit_transform(X) + + # update var_len_categorical_columns + for c in self.var_len_categorical_columns: + _encoder: VarLenFeatureEncoder = transformer._encoders[c.name] + c.max_elements_length = _encoder.max_element_length + + self.X_transformers['var_len_encoder'] = transformer + print(f'Encoder taken {time.time() - start}s') + return X + def _apply_gbm_features(self, X, y): start = time.time() logger.info('Extracting GBM features...') @@ -388,6 +435,26 @@ def _apply_gbm_features(self, X, y): print(f'Extracting gbm features taken {time.time() - start}s') return X + def __append_var_len_categorical_col(self, name, voc_size, sep, pooling_strategy): + logger.debug(f'Var len categorical variables {name} appended.') + + if self.config.fixed_embedding_dim: + embedding_output_dim = self.config.embeddings_output_dim if self.config.embeddings_output_dim > 0 else consts.EMBEDDING_OUT_DIM_DEFAULT + else: + embedding_output_dim = 0 + + if self.var_len_categorical_columns is None: + self.var_len_categorical_columns = [] + + vc = \ + VarLenCategoricalColumn(name, + voc_size, + embedding_output_dim if embedding_output_dim > 0 else min(4 * int(pow(voc_size, 0.25)), 20), + sep=sep, + pooling_strategy=pooling_strategy) + + self.var_len_categorical_columns.append(vc) + def __append_categorical_cols(self, cols): logger.debug(f'{len(cols)} categorical variables appended.') @@ -399,23 +466,32 @@ def __append_categorical_cols(self, cols): if self.categorical_columns is None: self.categorical_columns = [] - self.categorical_columns = self.categorical_columns + \ - [CategoricalColumn(name, - voc_size, - embedding_output_dim - if embedding_output_dim > 0 - else min(4 * int(pow(voc_size, 0.25)), 20)) - for name, voc_size in cols] + + if cols is not None and len(cols) > 0: + self.categorical_columns = self.categorical_columns + \ + [CategoricalColumn(name, + voc_size, + embedding_output_dim + if embedding_output_dim > 0 + else min(4 * int(pow(voc_size, 0.25)), 20)) + for name, voc_size in cols] def __append_continuous_cols(self, cols, input_name): if self.continuous_columns is None: self.continuous_columns = [] - self.continuous_columns = self.continuous_columns + [ContinuousColumn(name=input_name, - column_names=[c for c in cols])] + if cols is not None and len(cols) > 0: + self.continuous_columns = self.continuous_columns + [ContinuousColumn(name=input_name, + column_names=[c for c in cols])] def get_categorical_columns(self): return [c.name for c in self.categorical_columns] + def get_var_len_categorical_columns(self): + if self.var_len_categorical_columns is not None: + return [c.name for c in self.var_len_categorical_columns] + else: + return [] + def get_continuous_columns(self): cont_vars = [] for c in self.continuous_columns: diff --git a/deeptables/preprocessing/transformer.py b/deeptables/preprocessing/transformer.py index 1d4fe59..669e178 100644 --- a/deeptables/preprocessing/transformer.py +++ b/deeptables/preprocessing/transformer.py @@ -6,6 +6,7 @@ from sklearn.utils.validation import check_is_fitted from sklearn.utils import column_or_1d from ..utils import dt_logging, consts +from tensorflow.python.keras.preprocessing.sequence import pad_sequences from sklearn.pipeline import Pipeline from sklearn.base import BaseEstimator,TransformerMixin @@ -235,3 +236,67 @@ def transform(self, X): def fit_transform(self, X): self.fit(X) return self.transform(X) + + +class VarLenFeatureEncoder: + + def __init__(self, sep='|'): + self.sep = sep + self.encoder: SafeLabelEncoder = None + self._max_element_length = 0 + + def fit(self, X: pd.Series): + self._max_element_length = 0 # reset + if not isinstance(X, pd.Series): + X = pd.Series(X) + key_set = set() + # flat map + for keys in X.map(lambda _: _.split(self.sep)): + if len(keys) > self._max_element_length: + self._max_element_length = len(keys) + + for key in keys: + key_set.add(key) + lb = SafeLabelEncoder() # fix unseen values + lb.fit(np.array(list(key_set))) + self.encoder = lb + return self + + def transform(self, X: pd.Series): + if self.encoder is None: + raise RuntimeError("Not fit yet .") + + if not isinstance(X, pd.Series): + X = pd.Series(X) + # Notice : input value 0 is a special "padding",so we do not use 0 to encode valid feature for sequence input + data = X.map(lambda _: (self.encoder.transform(_.split(self.sep)) + 1).tolist()) + + return pad_sequences(data, maxlen=self._max_element_length, padding='post', truncating='post').tolist() # cut last elements + + @property + def n_classes(self): + return len(self.encoder.classes_) + + @property + def max_element_length(self): + return self._max_element_length + + +class MultiVarLenFeatureEncoder: + + def __init__(self, features): + self._encoders = {feature[0]: VarLenFeatureEncoder(feature[1]) for feature in features} + + def fit(self, X): + for k, v in self._encoders.items(): + v.fit(X[k]) + return self + + def transform(self, X): + for k, v in self._encoders.items(): + X[k] = v.transform(X[k]) + return X + + def fit_transform(self, X): + self.fit(X) + return self.transform(X) diff --git a/tests/models/model_input_test.py b/tests/models/model_input_test.py new file mode 100644 index 0000000..ce10977 --- /dev/null +++ b/tests/models/model_input_test.py @@ -0,0 +1,78 @@ +# -*- encoding: utf-8 -*- +import pandas as pd +import pytest + +from deeptables.datasets import dsutils +from deeptables.models import deeptable +from deeptables.utils import consts + + +class TestModelInput: + + def setup_class(cls): + cls.df_bank = dsutils.load_bank().sample(frac=0.01) + cls.df_movielens = dsutils.load_movielens() + + def _train_and_asset(self, X, y ,conf: deeptable.ModelConfig): + dt = deeptable.DeepTable(config=conf) + model, history = dt.fit(X, y, validation_split=0.2, epochs=2, batch_size=32) + assert len(model.model.input_names) == 1 + + def test_only_categorical_feature(self): + df = self.df_bank.copy() + X = df[['loan']] + y = df['y'] + conf = deeptable.ModelConfig(nets=['dnn_nets'], + task=consts.TASK_BINARY, + metrics=['accuracy'], + fixed_embedding_dim=True, + embeddings_output_dim=4, + apply_gbm_features=False, + apply_class_weight=True, + earlystopping_patience=3,) + self._train_and_asset(X, y, conf) + + def test_only_continuous_feature(self): + df = self.df_bank.copy() + X = df[['duration']].astype('float32') + y = df['y'] + conf = deeptable.ModelConfig(nets=['dnn_nets'], + task=consts.TASK_BINARY, + metrics=['accuracy'], + fixed_embedding_dim=True, + embeddings_output_dim=4, + apply_gbm_features=False, + apply_class_weight=True, + earlystopping_patience=3,) + self._train_and_asset(X, y, conf) + + def test_only_var_len_categorical_feature(self): + df:pd.DataFrame = self.df_movielens.copy() + X = df[['genres']] + y = df['rating'] + conf = deeptable.ModelConfig(nets=['dnn_nets'], + task=consts.TASK_REGRESSION, + metrics=['mse'], + fixed_embedding_dim=True, + embeddings_output_dim=4, + apply_gbm_features=False, + apply_class_weight=True, + earlystopping_patience=3,) + self._train_and_asset(X, y, conf) + + def test_no_input(self): + df:pd.DataFrame = self.df_movielens.copy() + X = pd.DataFrame() + y = df['rating'] + conf = deeptable.ModelConfig(nets=['dnn_nets'], + task=consts.TASK_REGRESSION, + metrics=['mse'], + fixed_embedding_dim=True, + embeddings_output_dim=4, + apply_gbm_features=False, + apply_class_weight=True, + earlystopping_patience=3,) + dt = deeptable.DeepTable(config=conf) + with pytest.raises(ValueError) as err_info: + dt.fit(X, y, validation_split=0.2, epochs=2, batch_size=32) + print(err_info) diff --git a/tests/models/test_var_len_categorical.py b/tests/models/test_var_len_categorical.py new file mode 100644 index 0000000..52c608d --- /dev/null +++ b/tests/models/test_var_len_categorical.py @@ -0,0 +1,49 @@ +# -*- encoding: utf-8 -*- +import numpy as np +from sklearn.model_selection import train_test_split +from deeptables.models import deeptable +from deeptables.preprocessing.transformer import MultiVarLenFeatureEncoder +from deeptables.utils import consts +from deeptables.datasets import dsutils + + +class TestVarLenCategoricalFeature: + + def setup_class(cls): + cls.df = dsutils.load_movielens().drop(['timestamp', "title"], axis=1) + + def test_encoder(self): + df = self.df.copy() + df['genres_copy'] = df['genres'] + + multi_encoder = MultiVarLenFeatureEncoder([('genres', '|'), ('genres_copy', '|'), ]) + result_df = multi_encoder.fit_transform(df) + + assert multi_encoder._encoders['genres'].max_element_length > 0 + assert multi_encoder._encoders['genres_copy'].max_element_length > 0 + + shape = np.array(result_df['genres'].tolist()).shape + assert shape[1] == multi_encoder._encoders['genres'].max_element_length + + def test_var_categorical_feature(self): + X = self.df.copy() + y = X.pop('rating').values.astype('float32') + + conf = deeptable.ModelConfig(nets=['dnn_nets'], + task=consts.TASK_REGRESSION, + categorical_columns=["movie_id", "user_id", "gender", "occupation", "zip", "title", "age"], + metrics=['mse'], + fixed_embedding_dim=True, + embeddings_output_dim=4, + apply_gbm_features=False, + apply_class_weight=True, + earlystopping_patience=5, + var_len_categorical_columns=[('genres', "|", "max")]) + + dt = deeptable.DeepTable(config=conf) + + X_train, X_validation, y_train, y_validation = train_test_split(X, y, test_size=0.2) + + model, history = dt.fit(X_train, y_train, validation_data=(X_validation, y_validation), epochs=10, batch_size=32) + + assert 'genres' in model.model.input_names