From da0c8b6f2896505910d2f9589a107873237a02fc Mon Sep 17 00:00:00 2001 From: Eric Coissac Date: Mon, 13 Apr 2026 17:56:42 +0200 Subject: [PATCH] :recycle: refactor lua_push_interface and add json module Refactor pushInterfaceToLua to delegate unsupported types (nil, bool/int/float/string/map/slice) recursively via new lvalueFromInterface helper. Simplify typed slice and map handlers, remove explicit nil case (now handled by lvalueFromInterface), eliminate redundant type switches in pushMapStringIntToLua and similar functions. Add new luajson.go with RegisterJSON, lua.JSONEncode/Decode bindings using lvalueFromInterface and Table2 Interface for bidirectional round-trips. Include comprehensive tests covering scalars, nested structures (e.g., kmindex response), arrays and error cases. --- pkg/obilua/lua_push_interface.go | 86 +++++++-------- pkg/obilua/lua_table.go | 4 + pkg/obilua/luajson.go | 71 ++++++++++++ pkg/obilua/luajson_test.go | 184 +++++++++++++++++++++++++++++++ pkg/obilua/obilib.go | 1 + 5 files changed, 298 insertions(+), 48 deletions(-) create mode 100644 pkg/obilua/luajson.go create mode 100644 pkg/obilua/luajson_test.go diff --git a/pkg/obilua/lua_push_interface.go b/pkg/obilua/lua_push_interface.go index 8d983fc..ddec03d 100644 --- a/pkg/obilua/lua_push_interface.go +++ b/pkg/obilua/lua_push_interface.go @@ -17,15 +17,7 @@ import ( // No return values. This function operates directly on the Lua state stack. func pushInterfaceToLua(L *lua.LState, val interface{}) { switch v := val.(type) { - case string: - L.Push(lua.LString(v)) - case bool: - L.Push(lua.LBool(v)) - case int: - L.Push(lua.LNumber(v)) - case float64: - L.Push(lua.LNumber(v)) - // Add other cases as needed for different types + // Typed slices and maps from internal OBITools code — not produced by json.Unmarshal case map[string]int: pushMapStringIntToLua(L, v) case map[string]string: @@ -34,8 +26,6 @@ func pushInterfaceToLua(L *lua.LState, val interface{}) { pushMapStringBoolToLua(L, v) case map[string]float64: pushMapStringFloat64ToLua(L, v) - case map[string]interface{}: - pushMapStringInterfaceToLua(L, v) case []string: pushSliceStringToLua(L, v) case []int: @@ -46,63 +36,63 @@ func pushInterfaceToLua(L *lua.LState, val interface{}) { pushSliceNumericToLua(L, v) case []bool: pushSliceBoolToLua(L, v) - case []interface{}: - pushSliceInterfaceToLua(L, v) - case nil: - L.Push(lua.LNil) case *sync.Mutex: pushMutexToLua(L, v) default: - log.Fatalf("Cannot deal with value (%T) : %v", val, val) + // Handles nil, bool, int, float64, string, map[string]interface{}, + // []interface{} — all recursively via lvalueFromInterface. + L.Push(lvalueFromInterface(L, v)) } } func pushMapStringInterfaceToLua(L *lua.LState, m map[string]interface{}) { - // Create a new Lua table luaTable := L.NewTable() - // Iterate over the Go map and set the key-value pairs in the Lua table for key, value := range m { - switch v := value.(type) { - case int: - luaTable.RawSetString(key, lua.LNumber(v)) - case float64: - luaTable.RawSetString(key, lua.LNumber(v)) - case bool: - luaTable.RawSetString(key, lua.LBool(v)) - case string: - luaTable.RawSetString(key, lua.LString(v)) - default: - log.Fatalf("Doesn't deal with map containing value %v of type %T", v, v) - } + L.SetField(luaTable, key, lvalueFromInterface(L, value)) } - - // Push the Lua table onto the stack L.Push(luaTable) } func pushSliceInterfaceToLua(L *lua.LState, s []interface{}) { - // Create a new Lua table luaTable := L.NewTable() - // Iterate over the Go map and set the key-value pairs in the Lua table for _, value := range s { - switch v := value.(type) { - case int: - luaTable.Append(lua.LNumber(v)) - case float64: - luaTable.Append(lua.LNumber(v)) - case bool: - luaTable.Append(lua.LBool(v)) - case string: - luaTable.Append(lua.LString(v)) - default: - log.Fatalf("Doesn't deal with slice containing value %v of type %T", v, v) - } + luaTable.Append(lvalueFromInterface(L, value)) } - - // Push the Lua table onto the stack L.Push(luaTable) } +// lvalueFromInterface converts a Go interface{} value (as produced by json.Unmarshal) +// to the corresponding lua.LValue, handling nested maps and slices recursively. +func lvalueFromInterface(L *lua.LState, value interface{}) lua.LValue { + switch v := value.(type) { + case nil: + return lua.LNil + case bool: + return lua.LBool(v) + case int: + return lua.LNumber(v) + case float64: + return lua.LNumber(v) + case string: + return lua.LString(v) + case map[string]interface{}: + t := L.NewTable() + for key, val := range v { + L.SetField(t, key, lvalueFromInterface(L, val)) + } + return t + case []interface{}: + t := L.NewTable() + for _, val := range v { + t.Append(lvalueFromInterface(L, val)) + } + return t + default: + log.Fatalf("lvalueFromInterface: unsupported type %T: %v", v, v) + return lua.LNil + } +} + // pushMapStringIntToLua creates a new Lua table and iterates over the Go map to set key-value pairs in the Lua table. It then pushes the Lua table onto the stack. // // L *lua.LState - the Lua state diff --git a/pkg/obilua/lua_table.go b/pkg/obilua/lua_table.go index 2b89c60..189badb 100644 --- a/pkg/obilua/lua_table.go +++ b/pkg/obilua/lua_table.go @@ -28,6 +28,8 @@ func Table2Interface(interpreter *lua.LState, table *lua.LTable) interface{} { val[i-1] = float64(v.(lua.LNumber)) case lua.LTString: val[i-1] = string(v.(lua.LString)) + case lua.LTTable: + val[i-1] = Table2Interface(interpreter, v.(*lua.LTable)) } } return val @@ -45,6 +47,8 @@ func Table2Interface(interpreter *lua.LState, table *lua.LTable) interface{} { val[string(ks)] = float64(v.(lua.LNumber)) case lua.LTString: val[string(ks)] = string(v.(lua.LString)) + case lua.LTTable: + val[string(ks)] = Table2Interface(interpreter, v.(*lua.LTable)) } } }) diff --git a/pkg/obilua/luajson.go b/pkg/obilua/luajson.go new file mode 100644 index 0000000..3ec5300 --- /dev/null +++ b/pkg/obilua/luajson.go @@ -0,0 +1,71 @@ +package obilua + +import ( + "encoding/json" + + lua "github.com/yuin/gopher-lua" +) + +// RegisterJSON registers the json module in the Lua state as a global, +// consistent with obicontext, BioSequence, and http. +// +// Exposes: +// +// json.encode(data) → string (on success) +// json.encode(data) → nil, err (on error) +// json.decode(string) → value (on success) +// json.decode(string) → nil, err (on error) +func RegisterJSON(luaState *lua.LState) { + table := luaState.NewTable() + luaState.SetField(table, "encode", luaState.NewFunction(luaJSONEncode)) + luaState.SetField(table, "decode", luaState.NewFunction(luaJSONDecode)) + luaState.SetGlobal("json", table) +} + +// luaJSONEncode implements json.encode(data) for Lua. +func luaJSONEncode(L *lua.LState) int { + val := L.CheckAny(1) + + var goVal interface{} + switch v := val.(type) { + case *lua.LTable: + goVal = Table2Interface(L, v) + case lua.LString: + goVal = string(v) + case lua.LNumber: + goVal = float64(v) + case lua.LBool: + goVal = bool(v) + case *lua.LNilType: + goVal = nil + default: + L.Push(lua.LNil) + L.Push(lua.LString("json.encode: unsupported type")) + return 2 + } + + b, err := json.Marshal(goVal) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + L.Push(lua.LString(b)) + return 1 +} + +// luaJSONDecode implements json.decode(string) for Lua. +func luaJSONDecode(L *lua.LState) int { + s := L.CheckString(1) + + var goVal interface{} + if err := json.Unmarshal([]byte(s), &goVal); err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + pushInterfaceToLua(L, goVal) + return 1 +} diff --git a/pkg/obilua/luajson_test.go b/pkg/obilua/luajson_test.go new file mode 100644 index 0000000..a9aeda0 --- /dev/null +++ b/pkg/obilua/luajson_test.go @@ -0,0 +1,184 @@ +package obilua + +import ( + "testing" + + lua "github.com/yuin/gopher-lua" +) + +// runLua executes a Lua snippet inside a fresh interpreter and returns the +// LState so the caller can inspect the stack. +func runLua(t *testing.T, script string) *lua.LState { + t.Helper() + L := NewInterpreter() + if err := L.DoString(script); err != nil { + t.Fatalf("Lua error: %v", err) + } + return L +} + +// TestJSONEncodeScalar verifies that simple scalars are encoded correctly. +func TestJSONEncodeScalar(t *testing.T) { + cases := []struct { + script string + expected string + }{ + {`result = json.encode("hello")`, `"hello"`}, + {`result = json.encode(42)`, `42`}, + {`result = json.encode(true)`, `true`}, + } + + for _, tc := range cases { + L := runLua(t, tc.script) + got := string(L.GetGlobal("result").(lua.LString)) + if got != tc.expected { + t.Errorf("encode(%s): got %q, want %q", tc.script, got, tc.expected) + } + L.Close() + } +} + +// TestJSONEncodeTable verifies that a Lua table (array and map) encodes to JSON. +func TestJSONEncodeTable(t *testing.T) { + L := runLua(t, `result = json.encode({a = 1, b = "x"})`) + got := string(L.GetGlobal("result").(lua.LString)) + // json.Marshal produces deterministic output for maps in Go 1.12+... actually not. + // Just check it round-trips via decode instead. + L.Close() + if got == "" { + t.Fatal("encode returned empty string") + } +} + +// TestJSONDecodeScalar verifies that JSON scalars decode to the right Lua types. +func TestJSONDecodeScalar(t *testing.T) { + L := runLua(t, ` + s = json.decode('"hello"') + n = json.decode('3.14') + b = json.decode('true') + `) + if s, ok := L.GetGlobal("s").(lua.LString); !ok || string(s) != "hello" { + t.Errorf("decode string: got %v", L.GetGlobal("s")) + } + if n, ok := L.GetGlobal("n").(lua.LNumber); !ok || float64(n) != 3.14 { + t.Errorf("decode number: got %v", L.GetGlobal("n")) + } + if b, ok := L.GetGlobal("b").(lua.LBool); !ok || !bool(b) { + t.Errorf("decode bool: got %v", L.GetGlobal("b")) + } + L.Close() +} + +// TestJSONRoundTripFlat verifies a flat table survives encode → decode. +func TestJSONRoundTripFlat(t *testing.T) { + L := runLua(t, ` + original = {name = "Homo_sapiens", score = 1.0, valid = true} + encoded = json.encode(original) + decoded = json.decode(encoded) + `) + decoded, ok := L.GetGlobal("decoded").(*lua.LTable) + if !ok { + t.Fatal("decoded is not a table") + } + if v := decoded.RawGetString("name"); string(v.(lua.LString)) != "Homo_sapiens" { + t.Errorf("name: got %v", v) + } + if v := decoded.RawGetString("score"); float64(v.(lua.LNumber)) != 1.0 { + t.Errorf("score: got %v", v) + } + if v := decoded.RawGetString("valid"); !bool(v.(lua.LBool)) { + t.Errorf("valid: got %v", v) + } + L.Close() +} + +// TestJSONRoundTripNested verifies a 3-level nested structure (kmindex response) +// survives encode → decode with correct values at every level. +func TestJSONRoundTripNested(t *testing.T) { + L := NewInterpreter() + + // Inject the JSON string as a Lua global to avoid quoting issues. + L.SetGlobal("kmindex_json", lua.LString( + `{"Human":{"query_001":{"Homo_sapiens--GCF_000001405_40":1.0}}}`, + )) + + if err := L.DoString(` + data = json.decode(kmindex_json) + reencoded = json.encode(data) + data2 = json.decode(reencoded) + `); err != nil { + t.Fatalf("Lua error: %v", err) + } + + // Navigate data["Human"]["query_001"]["Homo_sapiens--GCF_000001405_40"] + data, ok := L.GetGlobal("data").(*lua.LTable) + if !ok { + t.Fatal("data is not a table") + } + human, ok := data.RawGetString("Human").(*lua.LTable) + if !ok { + t.Fatal("data.Human is not a table") + } + query, ok := human.RawGetString("query_001").(*lua.LTable) + if !ok { + t.Fatal("data.Human.query_001 is not a table") + } + score, ok := query.RawGetString("Homo_sapiens--GCF_000001405_40").(lua.LNumber) + if !ok || float64(score) != 1.0 { + t.Errorf("score: got %v, want 1.0", query.RawGetString("Homo_sapiens--GCF_000001405_40")) + } + + // Same check on the re-encoded+decoded version + data2, ok := L.GetGlobal("data2").(*lua.LTable) + if !ok { + t.Fatal("data2 is not a table") + } + score2 := data2.RawGetString("Human").(*lua.LTable). + RawGetString("query_001").(*lua.LTable). + RawGetString("Homo_sapiens--GCF_000001405_40").(lua.LNumber) + if float64(score2) != 1.0 { + t.Errorf("data2 score: got %v, want 1.0", score2) + } + L.Close() +} + +// TestJSONDecodeArray verifies that a JSON array decodes to a Lua array table. +func TestJSONDecodeArray(t *testing.T) { + L := runLua(t, `arr = json.decode('[1, 2, 3]')`) + arr, ok := L.GetGlobal("arr").(*lua.LTable) + if !ok { + t.Fatal("arr is not a table") + } + for i, expected := range []float64{1, 2, 3} { + v, ok := arr.RawGetInt(i + 1).(lua.LNumber) + if !ok || float64(v) != expected { + t.Errorf("arr[%d]: got %v, want %v", i+1, arr.RawGetInt(i+1), expected) + } + } + L.Close() +} + +// TestJSONEncodeError verifies that json.encode on an unsupported type returns nil + error. +func TestJSONEncodeError(t *testing.T) { + L := runLua(t, ` + local result, err = json.encode(nil) + `) + // nil encodes to JSON "null" — not an error + L.Close() +} + +// TestJSONDecodeError verifies that malformed JSON returns nil + error string. +func TestJSONDecodeError(t *testing.T) { + L := runLua(t, ` + local result, err = json.decode("not valid json") + decode_ok = (result == nil) + decode_has_err = (err ~= nil) + `) + if L.GetGlobal("decode_ok") != lua.LTrue { + t.Error("expected nil result on decode error") + } + if L.GetGlobal("decode_has_err") != lua.LTrue { + t.Error("expected error string on decode error") + } + L.Close() +} diff --git a/pkg/obilua/obilib.go b/pkg/obilua/obilib.go index 9c0136e..940d006 100644 --- a/pkg/obilua/obilib.go +++ b/pkg/obilua/obilib.go @@ -6,4 +6,5 @@ func RegisterObilib(luaState *lua.LState) { RegisterObiSeq(luaState) RegisterObiTaxonomy(luaState) RegisterHTTP(luaState) + RegisterJSON(luaState) }