diff --git a/.gitignore b/.gitignore index e98ba75..ea47f34 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -data/* log/* *.log @@ -42,3 +41,5 @@ luac.out *.i*86 *.x86_64 *.hex + +*.conv diff --git a/163qr.jpeg b/163qr.jpeg new file mode 100644 index 0000000..ad040e5 Binary files /dev/null and b/163qr.jpeg differ diff --git a/README.md b/README.md deleted file mode 100644 index 42ba7af..0000000 --- a/README.md +++ /dev/null @@ -1,140 +0,0 @@ -# Neural Conversational Model in Torch - -This is an attempt at implementing [Sequence to Sequence Learning with Neural Networks (seq2seq)](http://arxiv.org/abs/1409.3215) and reproducing the results in [A Neural Conversational Model](http://arxiv.org/abs/1506.05869) (aka the Google chatbot). - -The Google chatbot paper [became famous](http://www.sciencealert.com/google-s-ai-bot-thinks-the-purpose-of-life-is-to-live-forever) after cleverly answering a few philosophical questions, such as: - -> **Human:** What is the purpose of living? -> **Machine:** To live forever. - -## How it works - -The model is based on two [LSTM](https://en.wikipedia.org/wiki/Long_short-term_memory) layers. One for encoding the input sentence into a "thought vector", and another for decoding that vector into a response. This model is called Sequence-to-sequence or seq2seq. - -![seq2seq](https://4.bp.blogspot.com/-aArS0l1pjHQ/Vjj71pKAaEI/AAAAAAAAAxE/Nvy1FSbD_Vs/s640/2TFstaticgraphic_alt-01.png) -_Source: http://googleresearch.blogspot.ca/2015/11/computer-respond-to-this-email.html_ - -In this experiment, we train the seq2seq model with movie dialogs from the [Cornell Movie-Dialogs Corpus](http://www.mpi-sws.org/~cristian/Cornell_Movie-Dialogs_Corpus.html). The lines are shortened to the first sentence. - -## Sample conversation - -Here's a sample conversation after training for 20 epoch with 50000 examples, using the following command: - -```sh -th train.lua --cuda --dataset 50000 --hiddenSize 1000 -``` - -(Took 3 days to train on my GeForce GTX 780M.) - -For OpenCL, use `--opencl` instead of `--cuda`. To train on CPU, don't provide any of those two. - -> **me:** Hello? -> **bot:** Hi. -> -> **me:** How are you? -> **bot:** I'm fine. -> -> **me:** What's your name? -> **bot:** It's hard to describe. -> -> **me:** How so? -> **bot:** I'm not sure. -> -> **me:** What color is the sky? -> **bot:** It's blue. -> -> **me:** What is your job? -> **bot:** It's not that i'm a fucking werewolf! -> -> **me:** What is the purpose of life? -> **bot:** A gift. -> -> **me:** Are you intelligent? -> **bot:** Yes, well... -> -> **me:** Are you a machine? -> **bot:** That's a lie. -> -> **me:** Are you human? -> **bot:** No, i'm not. -> -> **me:** What are you? -> **bot:** I'm not sure. -> -> **me:** Do you plan on taking over the world? -> **bot:** No, i don't. - -Phew! That was close. Good thing I didn't train it on the full dataset. Please experiment responsibly. - -_(Disclaimer: nonsensical responses have been removed.)_ - -## Installing - -1. [Install Torch](http://torch.ch/docs/getting-started.html). -2. Install the following additional Lua libs: - - ```sh - luarocks install nn - luarocks install rnn - luarocks install penlight - ``` - - To train with CUDA install the latest CUDA drivers, toolkit and run: - - ```sh - luarocks install cutorch - luarocks install cunn - ``` - - To train with opencl install the lastest Opencl torch lib: - - ```sh - luarocks install cltorch - luarocks install clnn - ``` - -3. 
Download the [Cornell Movie-Dialogs Corpus](http://www.mpi-sws.org/~cristian/Cornell_Movie-Dialogs_Corpus.html) and extract all the files into data/cornell_movie_dialogs. - -## Training - -```sh -th train.lua [-h / options] -``` - -Use the `--dataset NUMBER` option to control the size of the dataset. Training on the full dataset takes about 5h for a single epoch. - -The model will be saved to `data/model.t7` after each epoch if it has improved (error decreased). - -## Testing - -To load the model and have a conversation: - -```sh -th -i eval.lua --cuda # Skip --cuda if you didn't train with it -# ... -th> say "Hello." -``` - -## License - -MIT License - -Copyright (c) 2016 Marc-Andre Cournoyer - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file diff --git a/a.lua b/a.lua new file mode 100644 index 0000000..ac5cf28 --- /dev/null +++ b/a.lua @@ -0,0 +1,473 @@ +--- Lexical scanner for creating a sequence of tokens from text. +-- `lexer.scan(s)` returns an iterator over all tokens found in the +-- string `s`. This iterator returns two values, a token type string +-- (such as 'string' for quoted string, 'iden' for identifier) and the value of the +-- token. +-- +-- Versions specialized for Lua and C are available; these also handle block comments +-- and classify keywords as 'keyword' tokens. For example: +-- +-- > s = 'for i=1,n do' +-- > for t,v in lexer.lua(s) do print(t,v) end +-- keyword for +-- iden i +-- = = +-- number 1 +-- , , +-- iden n +-- keyword do +-- +-- See the Guide for further @{06-data.md.Lexical_Scanning|discussion} +-- @module pl.lexer + +local yield,wrap = coroutine.yield,coroutine.wrap +local strfind = string.find +local strsub = string.sub +local append = table.insert + +local function assert_arg(idx,val,tp) + if type(val) ~= tp then + error("argument "..idx.." 
must be "..tp, 2) + end +end + +local lexer = {} + +local NUMBER1 = '^[%+%-]?%d+%.?%d*[eE][%+%-]?%d+' +local NUMBER2 = '^[%+%-]?%d+%.?%d*' +local NUMBER3 = '^0x[%da-fA-F]+' +local NUMBER4 = '^%d+%.?%d*[eE][%+%-]?%d+' +local NUMBER5 = '^%d+%.?%d*' +local IDEN = '^[%a_][%w_]*' +local WSPACE = '^%s+' +local STRING1 = "^(['\"])%1" -- empty string +local STRING2 = [[^(['"])(\*)%2%1]] +local STRING3 = [[^(['"]).-[^\](\*)%2%1]] +local CHAR1 = "^''" +local CHAR2 = [[^'(\*)%1']] +local CHAR3 = [[^'.-[^\](\*)%1']] +local PREPRO = '^#.-[^\\]\n' + +local plain_matches,lua_matches,cpp_matches,lua_keyword,cpp_keyword + +local function tdump(tok) + return yield(tok,tok) +end + +local function ndump(tok,options) + if options and options.number then + tok = tonumber(tok) + end + return yield("number",tok) +end + +-- regular strings, single or double quotes; usually we want them +-- without the quotes +local function sdump(tok,options) + if options and options.string then + tok = tok:sub(2,-2) + end + return yield("string",tok) +end + +-- long Lua strings need extra work to get rid of the quotes +local function sdump_l(tok,options,findres) + if options and options.string then + local quotelen = 3 + if findres[3] then + quotelen = quotelen + findres[3]:len() + end + tok = tok:sub(quotelen, -quotelen) + if tok:sub(1, 1) == "\n" then + tok = tok:sub(2) + end + end + return yield("string",tok) +end + +local function chdump(tok,options) + if options and options.string then + tok = tok:sub(2,-2) + end + return yield("char",tok) +end + +local function cdump(tok) + return yield('comment',tok) +end + +local function wsdump (tok) + return yield("space",tok) +end + +local function pdump (tok) + return yield('prepro',tok) +end + +local function plain_vdump(tok) + return yield("iden",tok) +end + +local function lua_vdump(tok) + if lua_keyword[tok] then + return yield("keyword",tok) + else + return yield("iden",tok) + end +end + +local function cpp_vdump(tok) + if cpp_keyword[tok] then + return yield("keyword",tok) + else + return yield("iden",tok) + end +end + +--- create a plain token iterator from a string or file-like object. +-- @tparam string|file s a string or a file-like object with `:read()` method returning lines. +-- @tab matches an optional match table - array of token descriptions. +-- A token is described by a `{pattern, action}` pair, where `pattern` should match +-- token body and `action` is a function called when a token of described type is found. +-- @tab[opt] filter a table of token types to exclude, by default `{space=true}` +-- @tab[opt] options a table of options; by default, `{number=true,string=true}`, +-- which means convert numbers and strip string quotes. 
+function lexer.scan(s,matches,filter,options) + local file = type(s) ~= 'string' and s + filter = filter or {space=true} + options = options or {number=true,string=true} + if filter then + if filter.space then filter[wsdump] = true end + if filter.comments then + filter[cdump] = true + end + end + if not matches then + if not plain_matches then + plain_matches = { + {WSPACE,wsdump}, + {NUMBER3,ndump}, + {IDEN,plain_vdump}, + {NUMBER1,ndump}, + {NUMBER2,ndump}, + {STRING1,sdump}, + {STRING2,sdump}, + {STRING3,sdump}, + {'^.',tdump} + } + end + matches = plain_matches + end + local function lex() + local line_nr = 0 + local next_line = file and file:read() + local sz = file and 0 or #s + local idx = 1 + + while true do + if idx > sz then + if file then + if not next_line then return end + s = next_line + line_nr = line_nr + 1 + next_line = file:read() + if next_line then + s = s .. '\n' + end + idx, sz = 1, #s + else + return + end + end + + for _,m in ipairs(matches) do + local pat = m[1] + local fun = m[2] + local findres = {strfind(s,pat,idx)} + local i1, i2 = findres[1], findres[2] + if i1 then + local tok = strsub(s,i1,i2) + idx = i2 + 1 + local res + if not (filter and filter[fun]) then + lexer.finished = idx > sz + res = fun(tok, options, findres) + end + if res then + local tp = type(res) + -- insert a token list + if tp == 'table' then + yield('','') + for _,t in ipairs(res) do + yield(t[1],t[2]) + end + elseif tp == 'string' then -- or search up to some special pattern + i1,i2 = strfind(s,res,idx) + if i1 then + tok = strsub(s,i1,i2) + idx = i2 + 1 + yield('',tok) + else + yield('','') + idx = sz + 1 + end + else + yield(line_nr,idx) + end + end + + break + end + end + end + end + return wrap(lex) +end + +local function isstring (s) + return type(s) == 'string' +end + +--- insert tokens into a stream. +-- @param tok a token stream +-- @param a1 a string is the type, a table is a token list and +-- a function is assumed to be a token-like iterator (returns type & value) +-- @string a2 a string is the value +function lexer.insert (tok,a1,a2) + if not a1 then return end + local ts + if isstring(a1) and isstring(a2) then + ts = {{a1,a2}} + elseif type(a1) == 'function' then + ts = {} + for t,v in a1() do + append(ts,{t,v}) + end + else + ts = a1 + end + tok(ts) +end + +--- get everything in a stream upto a newline. +-- @param tok a token stream +-- @return a string +function lexer.getline (tok) + local t,v = tok('.-\n') + return v +end + +--- get current line number. +-- Only available if the input source is a file-like object. +-- @param tok a token stream +-- @return the line number and current column +function lexer.lineno (tok) + return tok(0) +end + +--- get the rest of the stream. +-- @param tok a token stream +-- @return a string +function lexer.getrest (tok) + local t,v = tok('.+') + return v +end + +--- get the Lua keywords as a set-like table. +-- So `res["and"]` etc would be `true`. +-- @return a table +function lexer.get_keywords () + if not lua_keyword then + lua_keyword = { + ["and"] = true, ["break"] = true, ["do"] = true, + ["else"] = true, ["elseif"] = true, ["end"] = true, + ["false"] = true, ["for"] = true, ["function"] = true, + ["if"] = true, ["in"] = true, ["local"] = true, ["nil"] = true, + ["not"] = true, ["or"] = true, ["repeat"] = true, + ["return"] = true, ["then"] = true, ["true"] = true, + ["until"] = true, ["while"] = true + } + end + return lua_keyword +end + +--- create a Lua token iterator from a string or file-like object. 
+-- Will return the token type and value. +-- @string s the string +-- @tab[opt] filter a table of token types to exclude, by default `{space=true,comments=true}` +-- @tab[opt] options a table of options; by default, `{number=true,string=true}`, +-- which means convert numbers and strip string quotes. +function lexer.lua(s,filter,options) + filter = filter or {space=true,comments=true} + lexer.get_keywords() + if not lua_matches then + lua_matches = { + {WSPACE,wsdump}, + {NUMBER3,ndump}, + {IDEN,lua_vdump}, + {NUMBER4,ndump}, + {NUMBER5,ndump}, + {STRING1,sdump}, + {STRING2,sdump}, + {STRING3,sdump}, + {'^%-%-%[(=*)%[.-%]%1%]',cdump}, + {'^%-%-.-\n',cdump}, + {'^%[(=*)%[.-%]%1%]',sdump_l}, + {'^==',tdump}, + {'^~=',tdump}, + {'^<=',tdump}, + {'^>=',tdump}, + {'^%.%.%.',tdump}, + {'^%.%.',tdump}, + {'^.',tdump} + } + end + return lexer.scan(s,lua_matches,filter,options) +end + +--- create a C/C++ token iterator from a string or file-like object. +-- Will return the token type type and value. +-- @string s the string +-- @tab[opt] filter a table of token types to exclude, by default `{space=true,comments=true}` +-- @tab[opt] options a table of options; by default, `{number=true,string=true}`, +-- which means convert numbers and strip string quotes. +function lexer.cpp(s,filter,options) + filter = filter or {space=true,comments=true} + if not cpp_keyword then + cpp_keyword = { + ["class"] = true, ["break"] = true, ["do"] = true, ["sizeof"] = true, + ["else"] = true, ["continue"] = true, ["struct"] = true, + ["false"] = true, ["for"] = true, ["public"] = true, ["void"] = true, + ["private"] = true, ["protected"] = true, ["goto"] = true, + ["if"] = true, ["static"] = true, ["const"] = true, ["typedef"] = true, + ["enum"] = true, ["char"] = true, ["int"] = true, ["bool"] = true, + ["long"] = true, ["float"] = true, ["true"] = true, ["delete"] = true, + ["double"] = true, ["while"] = true, ["new"] = true, + ["namespace"] = true, ["try"] = true, ["catch"] = true, + ["switch"] = true, ["case"] = true, ["extern"] = true, + ["return"] = true,["default"] = true,['unsigned'] = true,['signed'] = true, + ["union"] = true, ["volatile"] = true, ["register"] = true,["short"] = true, + } + end + if not cpp_matches then + cpp_matches = { + {WSPACE,wsdump}, + {PREPRO,pdump}, + {NUMBER3,ndump}, + {IDEN,cpp_vdump}, + {NUMBER4,ndump}, + {NUMBER5,ndump}, + {CHAR1,chdump}, + {CHAR2,chdump}, + {CHAR3,chdump}, + {STRING1,sdump}, + {STRING2,sdump}, + {STRING3,sdump}, + {'^//.-\n',cdump}, + {'^/%*.-%*/',cdump}, + {'^==',tdump}, + {'^!=',tdump}, + {'^<=',tdump}, + {'^>=',tdump}, + {'^->',tdump}, + {'^&&',tdump}, + {'^||',tdump}, + {'^%+%+',tdump}, + {'^%-%-',tdump}, + {'^%+=',tdump}, + {'^%-=',tdump}, + {'^%*=',tdump}, + {'^/=',tdump}, + {'^|=',tdump}, + {'^%^=',tdump}, + {'^::',tdump}, + {'^.',tdump} + } + end + return lexer.scan(s,cpp_matches,filter,options) +end + +--- get a list of parameters separated by a delimiter from a stream. +-- @param tok the token stream +-- @string[opt=')'] endtoken end of list. Can be '\n' +-- @string[opt=','] delim separator +-- @return a list of token lists. 
+function lexer.get_separated_list(tok,endtoken,delim) + endtoken = endtoken or ')' + delim = delim or ',' + local parm_values = {} + local level = 1 -- used to count ( and ) + local tl = {} + local function tappend (tl,t,val) + val = val or t + append(tl,{t,val}) + end + local is_end + if endtoken == '\n' then + is_end = function(t,val) + return t == 'space' and val:find '\n' + end + else + is_end = function (t) + return t == endtoken + end + end + local token,value + while true do + token,value=tok() + if not token then return nil,'EOS' end -- end of stream is an error! + if is_end(token,value) and level == 1 then + append(parm_values,tl) + break + elseif token == '(' then + level = level + 1 + tappend(tl,'(') + elseif token == ')' then + level = level - 1 + if level == 0 then -- finished with parm list + append(parm_values,tl) + break + else + tappend(tl,')') + end + elseif token == delim and level == 1 then + append(parm_values,tl) -- a new parm + tl = {} + else + tappend(tl,token,value) + end + end + return parm_values,{token,value} +end + +--- get the next non-space token from the stream. +-- @param tok the token stream. +function lexer.skipws (tok) + local t,v = tok() + while t == 'space' do + t,v = tok() + end + return t,v +end + +local skipws = lexer.skipws + +--- get the next token, which must be of the expected type. +-- Throws an error if this type does not match! +-- @param tok the token stream +-- @string expected_type the token type +-- @bool no_skip_ws whether we should skip whitespace +function lexer.expecting (tok,expected_type,no_skip_ws) + assert_arg(1,tok,'function') + assert_arg(2,expected_type,'string') + local t,v + if no_skip_ws then + t,v = tok() + else + t,v = skipws(tok) + end + if t ~= expected_type then error ("expecting "..expected_type,2) end + return v +end + +return lexer diff --git a/a.png b/a.png new file mode 100755 index 0000000..664f314 Binary files /dev/null and b/a.png differ diff --git a/b.png b/b.png new file mode 100755 index 0000000..e297c76 Binary files /dev/null and b/b.png differ diff --git a/c.png b/c.png new file mode 100755 index 0000000..fa06401 Binary files /dev/null and b/c.png differ diff --git a/cornell_movie_dialogs.lua b/cornell_movie_dialogs.lua old mode 100644 new mode 100755 index cf298d5..389f2c5 --- a/cornell_movie_dialogs.lua +++ b/cornell_movie_dialogs.lua @@ -2,71 +2,63 @@ local CornellMovieDialogs = torch.class("neuralconvo.CornellMovieDialogs") local stringx = require "pl.stringx" local xlua = require "xlua" -local function parsedLines(file, fields) - local f = assert(io.open(file, 'r')) - - return function() - local line = f:read("*line") - - if line == nil then - f:close() - return - end - - local values = stringx.split(line, " +++$+++ ") - local t = {} - - for i,field in ipairs(fields) do - t[field] = values[i] - end - - return t - end -end - function CornellMovieDialogs:__init(dir) self.dir = dir end -local MOVIE_LINES_FIELDS = {"lineID","characterID","movieID","character","text"} -local MOVIE_CONVERSATIONS_FIELDS = {"character1ID","character2ID","movieID","utteranceIDs"} -local TOTAL_LINES = 387810 - -local function progress(c) - if c % 10000 == 0 then - xlua.progress(c, TOTAL_LINES) - end -end function CornellMovieDialogs:load() local lines = {} local conversations = {} - local count = 0 + local count = 1 print("-- Parsing Cornell movie dialogs data set ...") - - for line in parsedLines(self.dir .. 
"/movie_lines.txt", MOVIE_LINES_FIELDS) do - lines[line.lineID] = line - line.lineID = nil - -- Remove unused fields - line.characterID = nil - line.movieID = nil - count = count + 1 - progress(count) + + + local f = assert(io.open('../xiaohuangji50w_fenciA.conv', 'r')) + + while true do + local line = f:read("*line") + if line == nil then + f:close() + break + end + + lines[count] = line + count = count + 1 end - for conv in parsedLines(self.dir .. "/movie_conversations.txt", MOVIE_CONVERSATIONS_FIELDS) do - local conversation = {} - local lineIDs = stringx.split(conv.utteranceIDs:sub(3, -3), "', '") - for i,lineID in ipairs(lineIDs) do - table.insert(conversation, lines[lineID]) + print("Total lines = "..count) + local tmpconv = nil + + local TOTAL = #lines + local count = 0 + + for i, line in ipairs(lines) do + --print(i..' '..line) + if string.sub(line, 0, 1) == "E" then + + if tmpconv ~= nil then + --print('new conv'..#tmpconv) + table.insert(conversations, tmpconv) + end + --print('e make the tmpconv') + tmpconv = {} + + end + + if string.sub(line, 0, 1) == "M" then + --print('insert into conv') + local tmpl = string.sub(line, 3, #line) + --print(tmpl) + table.insert(tmpconv, tmpl) end - table.insert(conversations, conversation) + count = count + 1 - progress(count) + if count%1000 == 0 then + xlua.progress(count, TOTAL) + end end - xlua.progress(TOTAL_LINES, TOTAL_LINES) - return conversations end diff --git a/cornell_movie_dialogs2.lua b/cornell_movie_dialogs2.lua new file mode 100644 index 0000000..cf298d5 --- /dev/null +++ b/cornell_movie_dialogs2.lua @@ -0,0 +1,72 @@ +local CornellMovieDialogs = torch.class("neuralconvo.CornellMovieDialogs") +local stringx = require "pl.stringx" +local xlua = require "xlua" + +local function parsedLines(file, fields) + local f = assert(io.open(file, 'r')) + + return function() + local line = f:read("*line") + + if line == nil then + f:close() + return + end + + local values = stringx.split(line, " +++$+++ ") + local t = {} + + for i,field in ipairs(fields) do + t[field] = values[i] + end + + return t + end +end + +function CornellMovieDialogs:__init(dir) + self.dir = dir +end + +local MOVIE_LINES_FIELDS = {"lineID","characterID","movieID","character","text"} +local MOVIE_CONVERSATIONS_FIELDS = {"character1ID","character2ID","movieID","utteranceIDs"} +local TOTAL_LINES = 387810 + +local function progress(c) + if c % 10000 == 0 then + xlua.progress(c, TOTAL_LINES) + end +end + +function CornellMovieDialogs:load() + local lines = {} + local conversations = {} + local count = 0 + + print("-- Parsing Cornell movie dialogs data set ...") + + for line in parsedLines(self.dir .. "/movie_lines.txt", MOVIE_LINES_FIELDS) do + lines[line.lineID] = line + line.lineID = nil + -- Remove unused fields + line.characterID = nil + line.movieID = nil + count = count + 1 + progress(count) + end + + for conv in parsedLines(self.dir .. 
"/movie_conversations.txt", MOVIE_CONVERSATIONS_FIELDS) do + local conversation = {} + local lineIDs = stringx.split(conv.utteranceIDs:sub(3, -3), "', '") + for i,lineID in ipairs(lineIDs) do + table.insert(conversation, lines[lineID]) + end + table.insert(conversations, conversation) + count = count + 1 + progress(count) + end + + xlua.progress(TOTAL_LINES, TOTAL_LINES) + + return conversations +end diff --git a/data/fate2.jpeg b/data/fate2.jpeg new file mode 100644 index 0000000..f14a7af Binary files /dev/null and b/data/fate2.jpeg differ diff --git a/data/qq2.jpeg b/data/qq2.jpeg new file mode 100644 index 0000000..5a105b5 Binary files /dev/null and b/data/qq2.jpeg differ diff --git a/data/qqun.png b/data/qqun.png new file mode 100644 index 0000000..a47f5cf Binary files /dev/null and b/data/qqun.png differ diff --git a/dataset.lua b/dataset.lua old mode 100644 new mode 100755 index 4664b0e..8e666a8 --- a/dataset.lua +++ b/dataset.lua @@ -41,6 +41,7 @@ function DataSet:load(loader) local filename = "data/vocab.t7" if path.exists(filename) then + --if false then print("Loading vocabulary from " .. filename .. " ...") local data = torch.load(filename) self.word2id = data.word2id @@ -50,6 +51,7 @@ function DataSet:load(loader) self.eosToken = data.eosToken self.unknownToken = data.unknownToken self.examplesCount = data.examplesCount + --print(self.word2id) else print("" .. filename .. " not found") self:visit(loader:load()) @@ -81,6 +83,7 @@ function DataSet:visit(conversations) local total = self.loadFirst or #conversations * 2 for i, conversation in ipairs(conversations) do + --print(i) if i > total then break end self:visitConversation(conversation) xlua.progress(i, total) @@ -88,12 +91,14 @@ function DataSet:visit(conversations) -- Revisit from the perspective of 2nd character for i, conversation in ipairs(conversations) do + --print(i) if #conversations + i > total then break end self:visitConversation(conversation, 2) xlua.progress(#conversations + i, total) end print("-- Removing low frequency words") + print("sfsgsdfgdf") for i, datum in ipairs(self.examples) do self:removeLowFreqWords(datum[1]) @@ -223,13 +228,15 @@ end function DataSet:visitConversation(lines, start) start = start or 1 + --print("conv lines "..#lines) + for i = start, #lines, 2 do local input = lines[i] local target = lines[i+1] if target then - local inputIds = self:visitText(input.text) - local targetIds = self:visitText(target.text, 2) + local inputIds = self:visitText(input) + local targetIds = self:visitText(target, 2) if inputIds and targetIds then -- Revert inputs @@ -248,18 +255,31 @@ function DataSet:visitText(text, additionalTokens) local words = {} additionalTokens = additionalTokens or 0 - if text == "" then + if text == "" or text == nil then + print "zero text" return end + --print(text) + local values = stringx.split(text, "/") + for i, word in ipairs(values) do + --print("spword:"..word) + table.insert(words, self:makeWordId(word)) + if #words >= self.maxExampleLen - additionalTokens then + break + end + end + +--[[ for t, word in tokenizer.tokenize(text) do + print(word) table.insert(words, self:makeWordId(word)) -- Only keep the first sentence if t == "endpunct" or #words >= self.maxExampleLen - additionalTokens then break end end - +]]-- if #words == 0 then return end @@ -268,13 +288,15 @@ function DataSet:visitText(text, additionalTokens) end function DataSet:makeWordId(word) - word = word:lower() - + --word = word:lower() + --print(word) local id = self.word2id[word] if id then 
self.wordFreq[word] = self.wordFreq[word] + 1 + --print("more freq > 1") else + --print("to dict word = "..word) self.wordsCount = self.wordsCount + 1 id = self.wordsCount self.id2word[id] = word diff --git a/eval-server.lua b/eval-server.lua new file mode 100755 index 0000000..fa4cd04 --- /dev/null +++ b/eval-server.lua @@ -0,0 +1,155 @@ +require 'neuralconvo' +local tokenizer = require "tokenizer" +local list = require "pl.List" +local options = {} + +if dataset == nil then + cmd = torch.CmdLine() + cmd:text('Options:') + cmd:option('--cuda', false, 'use CUDA. Training must be done on CUDA') + cmd:option('--opencl', false, 'use OpenCL. Training must be done on OpenCL') + cmd:option('--debug', false, 'show debug info') + cmd:text() + options = cmd:parse(arg) + + -- Data + dataset = neuralconvo.DataSet() + + -- Enabled CUDA + if options.cuda then + require 'cutorch' + require 'cunn' + elseif options.opencl then + require 'cltorch' + require 'clnn' + end +end + +if model == nil then + print("-- Loading model") + model = torch.load("data/model.t7") +end + +-- Word IDs to sentence +function pred2sent(wordIds, i) + local words = {} + i = i or 1 + + for _, wordId in ipairs(wordIds) do + local word = dataset.id2word[wordId[i]] + --print(wordId[i]..word) + table.insert(words, word) + end + + return tokenizer.join(words) +end + +function printProbabilityTable(wordIds, probabilities, num) + print(string.rep("-", num * 22)) + + for p, wordId in ipairs(wordIds) do + local line = "| " + for i = 1, num do + local word = dataset.id2word[wordId[i]] + line = line .. string.format("%-10s(%4d%%)", word, probabilities[p][i] * 100) .. " | " + end + print(line) + end + + print(string.rep("-", num * 22)) +end + +function say(text) + local wordIds = {} + + + + --print(text) + local values = {} + for w in text:gmatch("[\33-\127\192-\255]+[\128-\191]*") do + table.insert(values, w) + end + + for i, word in ipairs(values) do + local id = dataset.word2id[word] or dataset.unknownToken + --print(i.." "..word.." "..id) + + table.insert(wordIds, id) + + end + +--[[ + for t, word in tokenizer.tokenize(text) do + local id = dataset.word2id[word:lower()] or dataset.unknownToken + table.insert(wordIds, id) + end +]]-- + + local input = torch.Tensor(list.reverse(wordIds)) + local wordIds, probabilities = model:eval(input) + + local ret = pred2sent(wordIds) + print(">> " .. ret) + + if options.debug then + printProbabilityTable(wordIds, probabilities, 4) + end + + return ret + +end + + +--[[ http server using ASyNC]]-- + + function unescape (s) + s = string.gsub(s, "+", " ") + s = string.gsub(s, "%%(%x%x)", function (h) + return string.char(tonumber(h, 16)) + end) + return s + end + + +local async = require 'async' +require('pl.text').format_operator() + +async.http.listen('http://0.0.0.0:8082/', function(req,res) + print('request:',req) + local resp + + if req.url.path == '/' and req.url.query ~= nil and #req.url.query > 0 then + + local text_in = unescape(req.url.query) + print(text_in) + local ret = say(text_in) + resp = [[${data}]] % {data = ret} + + else + resp = 'Oops~ This is a wrong place, please goto here!' + + end + + -- if req.url.path == '/test' then + -- resp = [[ + --

<h1>You requested route /test</h1>
+ -- ]]
+ -- else
+ -- -- Produce a random story:
+ -- resp = [[
+ -- <h1>From my server</h1>
+ -- <p>It's working!</p>
+ -- <p>Randomly generated number: ${number}</p>
+ -- <p>A variable in the global scope: ${ret}</p>

+ -- ]] % { + -- number = math.random(), + -- ret = ret + -- } + -- end + + res(resp, {['Content-Type']='text/html; charset=UTF-8'}) +end) + +print('server listening to port 8082') + +async.go() \ No newline at end of file diff --git a/eval.lua b/eval.lua old mode 100644 new mode 100755 index 86d1772..a489bc6 --- a/eval.lua +++ b/eval.lua @@ -37,6 +37,7 @@ function pred2sent(wordIds, i) for _, wordId in ipairs(wordIds) do local word = dataset.id2word[wordId[i]] + --print(wordId[i]..word) table.insert(words, word) end @@ -58,17 +59,37 @@ function printProbabilityTable(wordIds, probabilities, num) print(string.rep("-", num * 22)) end + function say(text) local wordIds = {} + + + --print(text) + local values = {} + for w in text:gmatch("[\33-\127\192-\255]+[\128-\191]*") do + table.insert(values, w) + end + -- print(values) + for i, word in ipairs(values) do + local id = dataset.word2id[word] or dataset.unknownToken + -- print(i.." "..word.." "..id) + + table.insert(wordIds, id) + + end + +--[[ for t, word in tokenizer.tokenize(text) do local id = dataset.word2id[word:lower()] or dataset.unknownToken table.insert(wordIds, id) end +]]-- local input = torch.Tensor(list.reverse(wordIds)) local wordIds, probabilities = model:eval(input) + print(">> " .. pred2sent(wordIds)) if options.debug then diff --git a/lstm_text_generation.py b/lstm_text_generation.py new file mode 100755 index 0000000..dde8abd --- /dev/null +++ b/lstm_text_generation.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +'''(C) 2016 rust + +Example script to generate text from Nietzsche's writings. + +At least 20 epochs are required before the generated text +starts sounding coherent. + +It is recommended to run this script on GPU, as recurrent +networks are quite computationally intensive. + +If you try this script on new data, make sure your corpus +has at least ~100k characters. ~1M is better. 
+ + + +''' + +from __future__ import print_function +from keras.models import Sequential +from keras.layers import Dense, Activation, Dropout +from keras.layers import LSTM +from keras.utils.data_utils import get_file +import numpy as np +import random +import sys + + +#pandas 安装出错,所以直接从csv中解析 +text_lines = open('000001.csv').readlines()[1:] +print ('day count = ' + str(len(text_lines))) + +print (text_lines[0]) + +sources_all = [] +targets_all = [] + +for line in reversed(text_lines): + #print(line) + lw = line.split(',') + S = [float(lw[3]), float(lw[5])/10000, float(lw[6])/10000] + if len(sources_all) == 0: + S.append(0.00001) + S.append(0.00001) + S.append(0.00001) + + else: + last = sources_all[-1] + S.append((S[0] - last[0])/last[0]) + S.append((S[1] - last[1])/last[1]) + S.append((S[2] - last[2])/last[2]) + + + sources_all.append(S) + + T = S[3] + targets_all.append(T) + + +print(len(sources_all)) +print(len(targets_all)) + + +sources = sources_all[:5000] +targets = targets_all[:5000] +sources_test = sources_all[5000:] +targets_test = targets_all[5000:] + + +''' +path = get_file('nietzsche.txt', origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt") +text = open(path).read().lower() +print('corpus length:', len(text)) + +chars = sorted(list(set(text))) +print('total chars:', len(chars)) +char_indices = dict((c, i) for i, c in enumerate(chars)) +indices_char = dict((i, c) for i, c in enumerate(chars)) +''' + +# cut the text in semi-redundant sequences of maxlen characters +maxlen = 40 +step = 1 +sentences = [] +next_chars = [] +for i in range(0, len(sources) - maxlen, step): + sentences.append(sources[i: i + maxlen]) + next_chars.append(targets[i + maxlen]) +print('nb sequences:', len(sentences)) + +sentences_test = [] +next_chars_test = [] +for i in range(0, len(sources_test) - maxlen, step): + sentences_test.append(sources_test[i: i + maxlen]) + next_chars_test.append(targets_test[i + maxlen]) +print('nb test sequences:', len(sentences_test)) + +print('Vectorization...') +X = np.zeros((len(sentences), maxlen, 6), dtype=np.float32) +y = np.zeros((len(sentences), 1), dtype=np.float32) +for i, sentence in enumerate(sentences): + for t, char in enumerate(sentence): + for g in xrange(0,6): + X[i, t, g] = char[g] + y[i, 0] = next_chars[i] + + +# build the model: 2 stacked LSTM +print('Build model...') +model = Sequential() +model.add(LSTM(1024, return_sequences=True, input_shape=(maxlen, 6))) +model.add(LSTM(1024, return_sequences=False)) +model.add(Dropout(0.2)) +model.add(Dense(1)) +#model.add(Dense(1)) +#model.add(Activation('softmax')) +model.add(Activation('linear')) + +model.compile(loss='mse', optimizer='rmsprop') + +''' +def sample(a, temperature=1.0): + # helper function to sample an index from a probability array + a = np.log(a) / temperature + a = np.exp(a) / np.sum(np.exp(a)) + return np.argmax(np.random.multinomial(1, a, 1)) +''' + +# train the model, output generated text after each iteration +for iteration in range(1, 60): + print() + print('-' * 50) + print('Iteration', iteration) + model.fit(X, y, batch_size=128, nb_epoch=1) + + + predret = [] + for sent,targ in zip(sentences_test[:20], targets_test[:20]): + x = np.zeros((1, maxlen, 6)) + for t, char in enumerate(sent): + for g in xrange(0,6): + x[0, t, g] = char[g] + + preds = model.predict(x, verbose=0)[0] + print(preds[0], targ) + + + +print("cbf done!") diff --git a/material/stub b/material/stub new file mode 100644 index 0000000..e69de29 diff --git a/readme.md b/readme.md new file mode 100755 index 
0000000..e89c7be --- /dev/null +++ b/readme.md @@ -0,0 +1,21 @@ +# Neural Conversational Model in Torch + +Forked from https://github.com/chenb67/neuralconvo + +If you run into problems during training, please go through the original fork's issues first; this repo merely adapts it to Chinese! + + + +## How +Use https://github.com/dgkae/dgk_lost_conv as the training corpus. The Chinese sentences must be pre-segmented into words separated by '/'. We modified cornell_movie_dialogs.lua to support this format. Lua stores all strings (e.g. Chinese) as multibyte sequences, so the original pl.lexer does not work for Chinese; instead we use an external word-segmentation tool and '/' as the separator. + +## Result + +![result](a.png) +![result2](b.png) + + + + +## Rwt +This repo is no longer maintained; there are a few chat groups: diff --git a/run_server.sh b/run_server.sh new file mode 100755 index 0000000..7132c25 --- /dev/null +++ b/run_server.sh @@ -0,0 +1 @@ +th -i eval-server.lua --cuda diff --git a/tokenizer.lua b/tokenizer.lua old mode 100644 new mode 100755 index 071f0b0..5ba657f --- a/tokenizer.lua +++ b/tokenizer.lua @@ -27,10 +27,15 @@ local function endpunct(token) end local function unknown(token) + print("unknown") return yield("unknown", token) end function M.tokenize(text) + + print(text) + + --{ "^[\128-\193]+", word }, return lexer.scan(text, { { "^%s+", space }, { "^['\"]", quote }, diff --git a/train.lua b/train.lua old mode 100644 new mode 100755 index 0399464..912cb1c --- a/train.lua +++ b/train.lua @@ -12,7 +12,7 @@ cmd:option('--learningRate', 0.05, 'learning rate at t=0') cmd:option('--momentum', 0.9, 'momentum') cmd:option('--minLR', 0.00001, 'minimum learning rate') cmd:option('--saturateEpoch', 20, 'epoch at which linear decayed LR will reach minLR') -cmd:option('--maxEpoch', 50, 'maximum number of epochs to run') +cmd:option('--maxEpoch', 30, 'maximum number of epochs to run') cmd:option('--batchSize', 10, 'number of examples to load at once') cmd:text() @@ -63,6 +63,8 @@ end -- Run the experiment +print("dgk ending") +--exit() for epoch = 1, options.maxEpoch do print("\n-- Epoch " .. epoch .. " / " .. options.maxEpoch)
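
For reference, here is a minimal standalone sketch (not the repo's exact code) of how the dgk_lost_conv format described in readme.md is consumed: an `E` line starts a new conversation, an `M <utterance>` line appends one utterance, and each utterance is already word-segmented with `/` separators, mirroring the parsing change in cornell_movie_dialogs.lua and the `/`-splitting now done in `DataSet:visitText()`. The file path and helper names below are illustrative only.

```lua
-- Sketch only: condensed from the logic in cornell_movie_dialogs.lua / dataset.lua.
local stringx = require "pl.stringx"

-- Parse a dgk_lost_conv file: "E" opens a new conversation,
-- "M <utterance>" appends a pre-segmented utterance to it.
local function loadConversations(path)
  local conversations, current = {}, nil
  for line in io.lines(path) do
    if line:sub(1, 1) == "E" then
      if current then table.insert(conversations, current) end
      current = {}
    elseif line:sub(1, 1) == "M" and current then
      table.insert(current, line:sub(3))  -- drop the leading "M "
    end
  end
  if current then table.insert(conversations, current) end
  return conversations
end

-- Split one utterance into words on '/', as DataSet:visitText() now does.
local function splitWords(utterance)
  return stringx.split(utterance, "/")
end

local convs = loadConversations("xiaohuangji50w_fenciA.conv")  -- illustrative path
print(#convs .. " conversations")
print(table.concat(splitWords(convs[1][1]), " | "))
```

Because the corpus arrives already segmented, no tokenizer or lexer is needed at load time; each word only has to be mapped to an id by `DataSet:makeWordId()`.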