diff --git a/fst/_fst.pyx.tpl b/fst/_fst.pyx.tpl index 41a5abd..4e8ef84 100644 --- a/fst/_fst.pyx.tpl +++ b/fst/_fst.pyx.tpl @@ -1,9 +1,13 @@ +# distutils: language = c++ +# distutils: extra_compile_args = -O2 -std=c++11 + cimport libfst cimport sym import subprocess import random import re +from libcpp cimport bool as boolean from libcpp.vector cimport vector from libcpp.string cimport string from libcpp.pair cimport pair @@ -88,6 +92,15 @@ cdef class SymbolTable: """table.write(filename): save the symbol table to filename""" self.table.Write(as_str(filename)) + def to_bytes(self): + """table.to_bytes(): return the binary representation of the symbol table as a Python bytes object""" + + cdef ostringstream out + result = self.table.WriteStream(out) + cdef bytes out_str = out.str() + + return out_str + def find(self, key): """table.find(int value) -> decoded symbol if any symbol maps to this value table.find(str symbol) -> encoded value if this symbol is in the table""" @@ -442,12 +455,31 @@ cdef class {{fst}}(_Fst): self.fst.SetInputSymbols(self.isyms.table) if keep_osyms and self.osyms is not None: self.fst.SetOutputSymbols(self.osyms.table) - result = self.fst.Write(as_str(filename)) + result = self.fst.Write(as_str(filename)) # reset symbols: self.fst.SetInputSymbols(NULL) self.fst.SetOutputSymbols(NULL) return result + def to_bytes(self, keep_isyms=False, keep_osyms=False): + """fst.to_bytes(keep_isyms=False, keep_osyms=False): return the binary representation of the transducer as a bytes object""" + if keep_isyms and self.isyms is not None: + self.fst.SetInputSymbols(self.isyms.table) + if keep_osyms and self.osyms is not None: + self.fst.SetOutputSymbols(self.osyms.table) + + cdef libfst.FstWriteOptions* options = new libfst.FstWriteOptions("", True, keep_isyms, keep_osyms) + cdef ostringstream out + result = self.fst.WriteStream(out, options[0]) + del options + cdef bytes out_str = out.str() + + # reset symbols: + self.fst.SetInputSymbols(NULL) + self.fst.SetOutputSymbols(NULL) + + return out_str + 
property input_deterministic: def __get__(self): return bool(self.fst.Properties(libfst.kIDeterministic, True) & @@ -593,12 +625,12 @@ cdef class {{fst}}(_Fst): dist = [{{weight}}(distances[i].Value()) for i in range(distances.size())] return dist - def shortest_path(self, unsigned n=1): + def shortest_path(self, unsigned n=1, bint unique=False, bint first_path=False): """fst.shortest_path(int n=1) -> transducer containing the n shortest paths""" if not isinstance(self, StdVectorFst): raise TypeError('Weight needs to have the path property and be right distributive') cdef {{fst}} result = {{fst}}(isyms=self.isyms, osyms=self.osyms) - libfst.ShortestPath(self.fst[0], result.fst, n) + libfst.ShortestPath(self.fst[0], result.fst, n, unique, first_path) return result def push(self, final=False, weights=False, labels=False): diff --git a/fst/libfst.pxd.tpl b/fst/libfst.pxd.tpl index 1ef1b34..0a30ec8 100644 --- a/fst/libfst.pxd.tpl +++ b/fst/libfst.pxd.tpl @@ -1,3 +1,7 @@ +# distutils: language = c++ +# distutils: libraries = fst +# distutils: extra_compile_args = -O2 -std=c++11 + from libcpp.vector cimport vector from libcpp.string cimport string from libcpp.pair cimport pair @@ -30,11 +34,15 @@ cdef extern from "" namespace "fst": void Next() Arc& Value() + cdef cppclass FstWriteOptions: + FstWriteOptions(const string& src, bint hdr, bint isym, bint osym) + cdef cppclass Fst: int Start() unsigned NumArcs(int s) Fst* Copy() bint Write(string& filename) + bint WriteStream "Write" (ostream& strm, const FstWriteOptions& opts) uint64_t Properties(uint64_t mask, bint compute) cdef cppclass ExpandedFst(Fst): @@ -142,7 +150,7 @@ cdef extern from "" namespace "fst": cdef void Difference(Fst &ifst1, Fst &ifst2, MutableFst* ofst) cdef void Intersect(Fst &ifst1, Fst &ifst2, MutableFst* ofst) cdef void Reverse(Fst &ifst, MutableFst* ofst) - cdef void ShortestPath(Fst &ifst, MutableFst* ofst, unsigned n) + cdef void ShortestPath(Fst &ifst, MutableFst* ofst, unsigned n, bint unique, 
bint first_path) cdef void ArcMap (Fst &ifst, MutableFst* ofst, ArcMapper mapper) {{#types}} cdef void ShortestDistance(Fst &fst, vector[{{weight}}]* distance, bint reverse) diff --git a/fst/sym.pxd b/fst/sym.pxd index 27c638f..96733d9 100644 --- a/fst/sym.pxd +++ b/fst/sym.pxd @@ -1,5 +1,5 @@ from libcpp.string cimport string -from util cimport istream +from util cimport istream, ostream cdef extern from "" namespace "fst": cdef cppclass SymbolTable: @@ -9,7 +9,7 @@ cdef extern from "" namespace "fst": long AddSymbol(string &symbol) string& Name() bint Write(string &filename) - #WriteText (ostream &strm) + bint WriteStream "Write" (ostream &strm) string Find(long key) long Find(const char* symbol) unsigned NumSymbols() diff --git a/setup.py b/setup.py index 7d816de..f3428b5 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,17 @@ +import glob +import logging +import subprocess import sys import os +import distutils +from distutils.cmd import Command from distutils.core import setup from distutils.extension import Extension +import pystache +import yaml +from Cython.Build import cythonize + INC, LIB = [], [] # MacPorts @@ -10,15 +19,67 @@ INC.append('/opt/local/include') LIB.append('/opt/local/lib') +# Homebrew +if sys.platform == 'darwin' and os.path.isdir('/usr/local/lib'): + INC.append('/usr/local/include') + LIB.append('/usr/local/lib') + + ext_modules = [ Extension(name='fst._fst', - sources=['fst/_fst.cpp'], - libraries=['fst'], - extra_compile_args=['-O2'], - include_dirs=INC, - library_dirs=LIB) + sources=['fst/_fst.pyx.tpl', 'fst/libfst.pxd.tpl', 'fst/types.yml'], + include_dirs=INC, + library_dirs=LIB) ] + +def pystache_render(template_filename, context_filenames, rendered_filename): + with open(rendered_filename, 'w') as rendered: + with open(template_filename) as template: + contexts = [] + for context_filename in context_filenames: + with open(context_filename) as context: + contexts += list(yaml.safe_load_all(context)) + 
rendered.write(pystache.render(template.read(), context=contexts[0])) + + +def mustache_cmd_render(template_filename, context_filenames, rendered_filename, mustache_cmd='mustache'): + cmd = 'cat {yamls} | {mustache_cmd} - {template} > {rendered}'.format(yamls=' '.join(context_filenames), + template=template_filename, + rendered=rendered_filename, + mustache_cmd=mustache_cmd) + logging.info('Running command: %s', cmd) + subprocess.check_call(cmd, shell=True) + + +def mustachize(modules, mustache_command=''): + """Render each module's .tpl sources with its .yml contexts, then replace module.sources with the untouched sources plus the rendered .pyx files; returns modules.""" + for module in modules: + sources = module.sources + + context_filenames = [source for source in sources if os.path.splitext(source)[-1] in {'.yml'}] + templates = [source for source in sources if os.path.splitext(source)[-1] in {'.tpl'}] + other = set(sources) - set(context_filenames) - set(templates) + + new_sources = [] + for template in templates: + rendered, new_ext = os.path.splitext(template) + + if mustache_command: + mustache_cmd_render(template, context_filenames, rendered, mustache_cmd=mustache_command) + else: + pystache_render(template, context_filenames, rendered) + + _, new_ext = os.path.splitext(rendered) + if new_ext in {'.pyx'}: + new_sources.append(rendered) + + module.sources[:] = list(other) + new_sources + + return modules + + long_description = """ pyfst ===== @@ -59,5 +120,5 @@ 'Intended Audience :: Education', 'Intended Audience :: Science/Research'], packages=['fst'], - ext_modules=ext_modules + ext_modules=cythonize(mustachize(ext_modules)), )