From 2173938eb08ed35d4dc3eea86ddabcfc01fe5b9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petri=20H=C3=A4kkinen?= Date: Fri, 15 Dec 2023 01:05:51 +0200 Subject: [PATCH] Add tagged lightuserdata (#1087) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change adds support for tagged lightuserdata and optional custom typenames for lightuserdata. Background: Lightuserdata is an efficient representation for many kinds of unmanaged handles and resources in a game engine. However, currently the VM only supports one kind of lightuserdata, which makes it problematic in practice. For example, it's not possible to distinguish between different kinds of lightuserdata in Lua bindings, which can lead to unsafe practices and even crashes when a wrong kind of lightuserdata is passed to a binding function. Tagged lightuserdata work similarly to tagged userdata, i.e. they allow checking the tag quickly using lua_tolightuserdatatagged (or lua_lightuserdatatag). The tag is stored in the 'extra' field of TValue so it will add no cost to the (untagged) lightuserdata type. Alternatives would be to use full userdata values or use bitpacking to embed type information into lightuserdata on application level. Unfortunately these options are not that great in practice: full userdata have major performance implications and bitpacking fails in cases where full 64 bits are already used (e.g. pointers or 64-bit hashes). Lightuserdata names are not strictly necessary but they are rather convenient when debugging Lua code. More precise error messages and tostring returning more specific typename are useful to have in practice (e.g. "resource" or "entity" instead of the more generic "userdata"). Impl note: I did not add support for renaming tags in lua_setlightuserdataname as I'm not sure if it's possible to free fixed strings. If it's simple enough, maybe we should allow renaming (although I can't think of a specific need for it)? --------- Co-authored-by: Petri Häkkinen --- CodeGen/src/CodeGenUtils.cpp | 8 ++++---- VM/include/lua.h | 8 +++++++- VM/include/luaconf.h | 5 +++++ VM/src/lapi.cpp | 37 ++++++++++++++++++++++++++++++++++-- VM/src/lobject.cpp | 4 ++-- VM/src/lobject.h | 10 +++++++++- VM/src/lstate.cpp | 2 ++ VM/src/lstate.h | 2 ++ VM/src/ltm.cpp | 12 ++++++++++++ VM/src/lvmexecute.cpp | 16 +++++++++------- VM/src/lvmutils.cpp | 2 +- tests/Conformance.test.cpp | 27 ++++++++++++++++++++++++++ 12 files changed, 115 insertions(+), 18 deletions(-) diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index c1a9c3380..973829ca0 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -71,7 +71,7 @@ bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) if (!ttisnil(e)) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); setnvalue(ra + 3, double(index + 1)); setobj2s(L, ra + 4, e); @@ -90,7 +90,7 @@ bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) if (!ttisnil(gval(n))) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); getnodekey(L, ra + 3, n); setobj(L, ra + 4, gval(n)); @@ -115,7 +115,7 @@ bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra) if (!ttisnil(gval(n))) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); getnodekey(L, ra + 3, n); setobj(L, ra + 4, gval(n)); @@ -697,7 +697,7 @@ const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId ba { // set up registers for builtin iteration setobj2s(L, ra + 1, ra); - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); setnilvalue(ra); } else diff --git a/VM/include/lua.h b/VM/include/lua.h index dbf19cb4c..0390de7cf 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -159,9 +159,11 @@ LUA_API const char* lua_namecallatom(lua_State* L, int* atom); LUA_API int lua_objlen(lua_State* L, int idx); LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx); LUA_API void* lua_tolightuserdata(lua_State* L, int idx); +LUA_API void* lua_tolightuserdatatagged(lua_State* L, int idx, int tag); LUA_API void* lua_touserdata(lua_State* L, int idx); LUA_API void* lua_touserdatatagged(lua_State* L, int idx, int tag); LUA_API int lua_userdatatag(lua_State* L, int idx); +LUA_API int lua_lightuserdatatag(lua_State* L, int idx); LUA_API lua_State* lua_tothread(lua_State* L, int idx); LUA_API void* lua_tobuffer(lua_State* L, int idx, size_t* len); LUA_API const void* lua_topointer(lua_State* L, int idx); @@ -186,7 +188,7 @@ LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debug LUA_API void lua_pushboolean(lua_State* L, int b); LUA_API int lua_pushthread(lua_State* L); -LUA_API void lua_pushlightuserdata(lua_State* L, void* p); +LUA_API void lua_pushlightuserdatatagged(lua_State* L, void* p, int tag); LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); @@ -323,6 +325,9 @@ typedef void (*lua_Destructor)(lua_State* L, void* userdata); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor); LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag); +LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name); +LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); + LUA_API void lua_clonefunction(lua_State* L, int idx); LUA_API void lua_cleartable(lua_State* L, int idx); @@ -370,6 +375,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) #define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL) #define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL) +#define lua_pushlightuserdata(L, p) lua_pushlightuserdatatagged(L, p, 0) #define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) #define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index 7a1bbb950..910e259a4 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -101,6 +101,11 @@ #define LUA_UTAG_LIMIT 128 #endif +// number of valid Lua lightuserdata tags +#ifndef LUA_LUTAG_LIMIT +#define LUA_LUTAG_LIMIT 128 +#endif + // upper bound for number of size classes used by page allocator #ifndef LUA_SIZECLASSES #define LUA_SIZECLASSES 32 diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 355e4e21c..58c767f16 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -505,6 +505,12 @@ void* lua_tolightuserdata(lua_State* L, int idx) return (!ttislightuserdata(o)) ? NULL : pvalue(o); } +void* lua_tolightuserdatatagged(lua_State* L, int idx, int tag) +{ + StkId o = index2addr(L, idx); + return (!ttislightuserdata(o) || lightuserdatatag(o) != tag) ? NULL : pvalue(o); +} + void* lua_touserdata(lua_State* L, int idx) { StkId o = index2addr(L, idx); @@ -530,6 +536,14 @@ int lua_userdatatag(lua_State* L, int idx) return -1; } +int lua_lightuserdatatag(lua_State* L, int idx) +{ + StkId o = index2addr(L, idx); + if (ttislightuserdata(o)) + return lightuserdatatag(o); + return -1; +} + lua_State* lua_tothread(lua_State* L, int idx) { StkId o = index2addr(L, idx); @@ -665,9 +679,10 @@ void lua_pushboolean(lua_State* L, int b) api_incr_top(L); } -void lua_pushlightuserdata(lua_State* L, void* p) +void lua_pushlightuserdatatagged(lua_State* L, void* p, int tag) { - setpvalue(L->top, p); + api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); + setpvalue(L->top, p, tag); api_incr_top(L); } @@ -1412,6 +1427,24 @@ lua_Destructor lua_getuserdatadtor(lua_State* L, int tag) return L->global->udatagc[tag]; } +void lua_setlightuserdataname(lua_State* L, int tag, const char* name) +{ + api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); + api_check(L, !L->global->lightuserdataname[tag]); // renaming not supported + if (!L->global->lightuserdataname[tag]) + { + L->global->lightuserdataname[tag] = luaS_new(L, name); + luaS_fix(L->global->lightuserdataname[tag]); // never collect these names + } +} + +const char* lua_getlightuserdataname(lua_State* L, int tag) +{ + api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); + const TString* name = L->global->lightuserdataname[tag]; + return name ? getstr(name) : nullptr; +} + void lua_clonefunction(lua_State* L, int idx) { luaC_checkGC(L); diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index 88d8d7ca9..081e33147 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -48,7 +48,7 @@ int luaO_rawequalObj(const TValue* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // boolean true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2); + return pvalue(t1) == pvalue(t2) && (!FFlag::TaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); default: LUAU_ASSERT(iscollectable(t1)); return gcvalue(t1) == gcvalue(t2); @@ -71,7 +71,7 @@ int luaO_rawequalKey(const TKey* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // boolean true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2); + return pvalue(t1) == pvalue(t2) && (!FFlag::TaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); default: LUAU_ASSERT(iscollectable(t1)); return gcvalue(t1) == gcvalue(t2); diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 716401402..d236f7e43 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -80,6 +80,11 @@ typedef struct lua_TValue #define l_isfalse(o) (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) +#define lightuserdatatag(o) check_exp(ttislightuserdata(o), (o)->extra[0]) + +// Internal tags used by the VM +#define LU_TAG_ITERATOR LUA_UTAG_LIMIT + /* ** for internal debug only */ @@ -120,10 +125,11 @@ typedef struct lua_TValue } #endif -#define setpvalue(obj, x) \ +#define setpvalue(obj, x, tag) \ { \ TValue* i_o = (obj); \ i_o->value.p = (x); \ + i_o->extra[0] = (tag); \ i_o->tt = LUA_TLIGHTUSERDATA; \ } @@ -492,3 +498,5 @@ LUAI_FUNC int luaO_str2d(const char* s, double* result); LUAI_FUNC const char* luaO_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUAI_FUNC const char* luaO_pushfstring(lua_State* L, const char* fmt, ...); LUAI_FUNC const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t srclen); + +LUAU_FASTFLAG(TaggedLuData) diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 161dcda04..858f61a3a 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -210,6 +210,8 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) g->udatagc[i] = NULL; + for (i = 0; i < LUA_LUTAG_LIMIT; i++) + g->lightuserdataname[i] = NULL; for (i = 0; i < LUA_MEMORY_CATEGORIES; i++) g->memcatbytes[i] = 0; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index ed73d3d8e..2c6d35dc5 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -214,6 +214,8 @@ typedef struct global_State void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata + GCStats gcstats; #ifdef LUAI_GCMETRICS diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 927a535be..3a9fddaa3 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -129,6 +129,18 @@ const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) if (ttisstring(type)) return tsvalue(type); } + else if (FFlag::TaggedLuData && ttislightuserdata(o)) + { + int tag = lightuserdatatag(o); + + if (unsigned(tag) < LUA_LUTAG_LIMIT) + { + const TString* name = L->global->lightuserdataname[tag]; + + if (name) + return name; + } + } else if (Table* mt = L->global->mt[ttype(o)]) { const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index c1a3ca8e2..1c77fa14e 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -135,6 +135,8 @@ // Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost. #define VM_HAS_NATIVE 1 +LUAU_FASTFLAGVARIABLE(TaggedLuData, false) + LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) { ptrdiff_t base = savestack(L, L->base); @@ -1110,7 +1112,7 @@ static void luau_execute(lua_State* L) VM_NEXT(); case LUA_TLIGHTUSERDATA: - pc += pvalue(ra) == pvalue(rb) ? LUAU_INSN_D(insn) : 1; + pc += (pvalue(ra) == pvalue(rb) && (!FFlag::TaggedLuData || lightuserdatatag(ra) == lightuserdatatag(rb))) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); @@ -1225,7 +1227,7 @@ static void luau_execute(lua_State* L) VM_NEXT(); case LUA_TLIGHTUSERDATA: - pc += pvalue(ra) != pvalue(rb) ? LUAU_INSN_D(insn) : 1; + pc += (pvalue(ra) != pvalue(rb) || (FFlag::TaggedLuData && lightuserdatatag(ra) != lightuserdatatag(rb))) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); @@ -2296,7 +2298,7 @@ static void luau_execute(lua_State* L) { // set up registers for builtin iteration setobj2s(L, ra + 1, ra); - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); setnilvalue(ra); } else @@ -2348,7 +2350,7 @@ static void luau_execute(lua_State* L) if (!ttisnil(e)) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); setnvalue(ra + 3, double(index + 1)); setobj2s(L, ra + 4, e); @@ -2369,7 +2371,7 @@ static void luau_execute(lua_State* L) if (!ttisnil(gval(n))) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); getnodekey(L, ra + 3, n); setobj2s(L, ra + 4, gval(n)); @@ -2421,7 +2423,7 @@ static void luau_execute(lua_State* L) { setnilvalue(ra); // ra+1 is already the table - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); } else if (!ttisfunction(ra)) { @@ -2450,7 +2452,7 @@ static void luau_execute(lua_State* L) { setnilvalue(ra); // ra+1 is already the table - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); } else if (!ttisfunction(ra)) { diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index c4b0b47d8..851d778c0 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -288,7 +288,7 @@ int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2); + return pvalue(t1) == pvalue(t2) && (!FFlag::TaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); case LUA_TUSERDATA: { tm = get_compTM(L, uvalue(t1)->metatable, uvalue(t2)->metatable, TM_EQ); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index a9c5bc373..b530ce55b 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -32,6 +32,7 @@ LUAU_FASTFLAG(LuauBufferDefinitions); LUAU_FASTFLAG(LuauCodeGenFixByteLower); LUAU_FASTFLAG(LuauCompileBufferAnnotation); LUAU_FASTFLAG(LuauLoopInterruptFix); +LUAU_FASTFLAG(TaggedLuData); LUAU_DYNAMIC_FASTFLAG(LuauStricterUtf8); LUAU_FASTINT(CodegenHeuristicsInstructionLimit); @@ -1700,6 +1701,32 @@ TEST_CASE("UserdataApi") CHECK(dtorhits == 42); } +TEST_CASE("LightuserdataApi") +{ + ScopedFastFlag taggedLuData{FFlag::TaggedLuData, true}; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + void* value = (void*)0x12345678; + + lua_pushlightuserdatatagged(L, value, 1); + CHECK(lua_lightuserdatatag(L, -1) == 1); + CHECK(lua_tolightuserdatatagged(L, -1, 0) == nullptr); + CHECK(lua_tolightuserdatatagged(L, -1, 1) == value); + + lua_setlightuserdataname(L, 1, "id"); + CHECK(!lua_getlightuserdataname(L, 0)); + CHECK(strcmp(lua_getlightuserdataname(L, 1), "id") == 0); + CHECK(strcmp(luaL_typename(L, -1), "id") == 0); + + lua_pushlightuserdatatagged(L, value, 0); + lua_pushlightuserdatatagged(L, value, 1); + CHECK(lua_rawequal(L, -1, -2) == 0); + + globalState.reset(); +} + TEST_CASE("Iter") { runConformance("iter.lua");