diff --git a/pkg/obilua/lua_push_interface.go b/pkg/obilua/lua_push_interface.go index 8d983fc..ddec03d 100644 --- a/pkg/obilua/lua_push_interface.go +++ b/pkg/obilua/lua_push_interface.go @@ -17,15 +17,7 @@ import ( // No return values. This function operates directly on the Lua state stack. func pushInterfaceToLua(L *lua.LState, val interface{}) { switch v := val.(type) { - case string: - L.Push(lua.LString(v)) - case bool: - L.Push(lua.LBool(v)) - case int: - L.Push(lua.LNumber(v)) - case float64: - L.Push(lua.LNumber(v)) - // Add other cases as needed for different types + // Typed slices and maps from internal OBITools code — not produced by json.Unmarshal case map[string]int: pushMapStringIntToLua(L, v) case map[string]string: @@ -34,8 +26,6 @@ func pushInterfaceToLua(L *lua.LState, val interface{}) { pushMapStringBoolToLua(L, v) case map[string]float64: pushMapStringFloat64ToLua(L, v) - case map[string]interface{}: - pushMapStringInterfaceToLua(L, v) case []string: pushSliceStringToLua(L, v) case []int: @@ -46,63 +36,63 @@ func pushInterfaceToLua(L *lua.LState, val interface{}) { pushSliceNumericToLua(L, v) case []bool: pushSliceBoolToLua(L, v) - case []interface{}: - pushSliceInterfaceToLua(L, v) - case nil: - L.Push(lua.LNil) case *sync.Mutex: pushMutexToLua(L, v) default: - log.Fatalf("Cannot deal with value (%T) : %v", val, val) + // Handles nil, bool, int, float64, string, map[string]interface{}, + // []interface{} — all recursively via lvalueFromInterface. + L.Push(lvalueFromInterface(L, v)) } } func pushMapStringInterfaceToLua(L *lua.LState, m map[string]interface{}) { - // Create a new Lua table luaTable := L.NewTable() - // Iterate over the Go map and set the key-value pairs in the Lua table for key, value := range m { - switch v := value.(type) { - case int: - luaTable.RawSetString(key, lua.LNumber(v)) - case float64: - luaTable.RawSetString(key, lua.LNumber(v)) - case bool: - luaTable.RawSetString(key, lua.LBool(v)) - case string: - luaTable.RawSetString(key, lua.LString(v)) - default: - log.Fatalf("Doesn't deal with map containing value %v of type %T", v, v) - } + L.SetField(luaTable, key, lvalueFromInterface(L, value)) } - - // Push the Lua table onto the stack L.Push(luaTable) } func pushSliceInterfaceToLua(L *lua.LState, s []interface{}) { - // Create a new Lua table luaTable := L.NewTable() - // Iterate over the Go map and set the key-value pairs in the Lua table for _, value := range s { - switch v := value.(type) { - case int: - luaTable.Append(lua.LNumber(v)) - case float64: - luaTable.Append(lua.LNumber(v)) - case bool: - luaTable.Append(lua.LBool(v)) - case string: - luaTable.Append(lua.LString(v)) - default: - log.Fatalf("Doesn't deal with slice containing value %v of type %T", v, v) - } + luaTable.Append(lvalueFromInterface(L, value)) } - - // Push the Lua table onto the stack L.Push(luaTable) } +// lvalueFromInterface converts a Go interface{} value (as produced by json.Unmarshal) +// to the corresponding lua.LValue, handling nested maps and slices recursively. +func lvalueFromInterface(L *lua.LState, value interface{}) lua.LValue { + switch v := value.(type) { + case nil: + return lua.LNil + case bool: + return lua.LBool(v) + case int: + return lua.LNumber(v) + case float64: + return lua.LNumber(v) + case string: + return lua.LString(v) + case map[string]interface{}: + t := L.NewTable() + for key, val := range v { + L.SetField(t, key, lvalueFromInterface(L, val)) + } + return t + case []interface{}: + t := L.NewTable() + for _, val := range v { + t.Append(lvalueFromInterface(L, val)) + } + return t + default: + log.Fatalf("lvalueFromInterface: unsupported type %T: %v", v, v) + return lua.LNil + } +} + // pushMapStringIntToLua creates a new Lua table and iterates over the Go map to set key-value pairs in the Lua table. It then pushes the Lua table onto the stack. // // L *lua.LState - the Lua state diff --git a/pkg/obilua/lua_table.go b/pkg/obilua/lua_table.go index 2b89c60..189badb 100644 --- a/pkg/obilua/lua_table.go +++ b/pkg/obilua/lua_table.go @@ -28,6 +28,8 @@ func Table2Interface(interpreter *lua.LState, table *lua.LTable) interface{} { val[i-1] = float64(v.(lua.LNumber)) case lua.LTString: val[i-1] = string(v.(lua.LString)) + case lua.LTTable: + val[i-1] = Table2Interface(interpreter, v.(*lua.LTable)) } } return val @@ -45,6 +47,8 @@ func Table2Interface(interpreter *lua.LState, table *lua.LTable) interface{} { val[string(ks)] = float64(v.(lua.LNumber)) case lua.LTString: val[string(ks)] = string(v.(lua.LString)) + case lua.LTTable: + val[string(ks)] = Table2Interface(interpreter, v.(*lua.LTable)) } } }) diff --git a/pkg/obilua/luajson.go b/pkg/obilua/luajson.go new file mode 100644 index 0000000..3ec5300 --- /dev/null +++ b/pkg/obilua/luajson.go @@ -0,0 +1,71 @@ +package obilua + +import ( + "encoding/json" + + lua "github.com/yuin/gopher-lua" +) + +// RegisterJSON registers the json module in the Lua state as a global, +// consistent with obicontext, BioSequence, and http. +// +// Exposes: +// +// json.encode(data) → string (on success) +// json.encode(data) → nil, err (on error) +// json.decode(string) → value (on success) +// json.decode(string) → nil, err (on error) +func RegisterJSON(luaState *lua.LState) { + table := luaState.NewTable() + luaState.SetField(table, "encode", luaState.NewFunction(luaJSONEncode)) + luaState.SetField(table, "decode", luaState.NewFunction(luaJSONDecode)) + luaState.SetGlobal("json", table) +} + +// luaJSONEncode implements json.encode(data) for Lua. +func luaJSONEncode(L *lua.LState) int { + val := L.CheckAny(1) + + var goVal interface{} + switch v := val.(type) { + case *lua.LTable: + goVal = Table2Interface(L, v) + case lua.LString: + goVal = string(v) + case lua.LNumber: + goVal = float64(v) + case lua.LBool: + goVal = bool(v) + case *lua.LNilType: + goVal = nil + default: + L.Push(lua.LNil) + L.Push(lua.LString("json.encode: unsupported type")) + return 2 + } + + b, err := json.Marshal(goVal) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + L.Push(lua.LString(b)) + return 1 +} + +// luaJSONDecode implements json.decode(string) for Lua. +func luaJSONDecode(L *lua.LState) int { + s := L.CheckString(1) + + var goVal interface{} + if err := json.Unmarshal([]byte(s), &goVal); err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + pushInterfaceToLua(L, goVal) + return 1 +} diff --git a/pkg/obilua/luajson_test.go b/pkg/obilua/luajson_test.go new file mode 100644 index 0000000..a9aeda0 --- /dev/null +++ b/pkg/obilua/luajson_test.go @@ -0,0 +1,184 @@ +package obilua + +import ( + "testing" + + lua "github.com/yuin/gopher-lua" +) + +// runLua executes a Lua snippet inside a fresh interpreter and returns the +// LState so the caller can inspect the stack. +func runLua(t *testing.T, script string) *lua.LState { + t.Helper() + L := NewInterpreter() + if err := L.DoString(script); err != nil { + t.Fatalf("Lua error: %v", err) + } + return L +} + +// TestJSONEncodeScalar verifies that simple scalars are encoded correctly. +func TestJSONEncodeScalar(t *testing.T) { + cases := []struct { + script string + expected string + }{ + {`result = json.encode("hello")`, `"hello"`}, + {`result = json.encode(42)`, `42`}, + {`result = json.encode(true)`, `true`}, + } + + for _, tc := range cases { + L := runLua(t, tc.script) + got := string(L.GetGlobal("result").(lua.LString)) + if got != tc.expected { + t.Errorf("encode(%s): got %q, want %q", tc.script, got, tc.expected) + } + L.Close() + } +} + +// TestJSONEncodeTable verifies that a Lua table (array and map) encodes to JSON. +func TestJSONEncodeTable(t *testing.T) { + L := runLua(t, `result = json.encode({a = 1, b = "x"})`) + got := string(L.GetGlobal("result").(lua.LString)) + // json.Marshal produces deterministic output for maps in Go 1.12+... actually not. + // Just check it round-trips via decode instead. + L.Close() + if got == "" { + t.Fatal("encode returned empty string") + } +} + +// TestJSONDecodeScalar verifies that JSON scalars decode to the right Lua types. +func TestJSONDecodeScalar(t *testing.T) { + L := runLua(t, ` + s = json.decode('"hello"') + n = json.decode('3.14') + b = json.decode('true') + `) + if s, ok := L.GetGlobal("s").(lua.LString); !ok || string(s) != "hello" { + t.Errorf("decode string: got %v", L.GetGlobal("s")) + } + if n, ok := L.GetGlobal("n").(lua.LNumber); !ok || float64(n) != 3.14 { + t.Errorf("decode number: got %v", L.GetGlobal("n")) + } + if b, ok := L.GetGlobal("b").(lua.LBool); !ok || !bool(b) { + t.Errorf("decode bool: got %v", L.GetGlobal("b")) + } + L.Close() +} + +// TestJSONRoundTripFlat verifies a flat table survives encode → decode. +func TestJSONRoundTripFlat(t *testing.T) { + L := runLua(t, ` + original = {name = "Homo_sapiens", score = 1.0, valid = true} + encoded = json.encode(original) + decoded = json.decode(encoded) + `) + decoded, ok := L.GetGlobal("decoded").(*lua.LTable) + if !ok { + t.Fatal("decoded is not a table") + } + if v := decoded.RawGetString("name"); string(v.(lua.LString)) != "Homo_sapiens" { + t.Errorf("name: got %v", v) + } + if v := decoded.RawGetString("score"); float64(v.(lua.LNumber)) != 1.0 { + t.Errorf("score: got %v", v) + } + if v := decoded.RawGetString("valid"); !bool(v.(lua.LBool)) { + t.Errorf("valid: got %v", v) + } + L.Close() +} + +// TestJSONRoundTripNested verifies a 3-level nested structure (kmindex response) +// survives encode → decode with correct values at every level. +func TestJSONRoundTripNested(t *testing.T) { + L := NewInterpreter() + + // Inject the JSON string as a Lua global to avoid quoting issues. + L.SetGlobal("kmindex_json", lua.LString( + `{"Human":{"query_001":{"Homo_sapiens--GCF_000001405_40":1.0}}}`, + )) + + if err := L.DoString(` + data = json.decode(kmindex_json) + reencoded = json.encode(data) + data2 = json.decode(reencoded) + `); err != nil { + t.Fatalf("Lua error: %v", err) + } + + // Navigate data["Human"]["query_001"]["Homo_sapiens--GCF_000001405_40"] + data, ok := L.GetGlobal("data").(*lua.LTable) + if !ok { + t.Fatal("data is not a table") + } + human, ok := data.RawGetString("Human").(*lua.LTable) + if !ok { + t.Fatal("data.Human is not a table") + } + query, ok := human.RawGetString("query_001").(*lua.LTable) + if !ok { + t.Fatal("data.Human.query_001 is not a table") + } + score, ok := query.RawGetString("Homo_sapiens--GCF_000001405_40").(lua.LNumber) + if !ok || float64(score) != 1.0 { + t.Errorf("score: got %v, want 1.0", query.RawGetString("Homo_sapiens--GCF_000001405_40")) + } + + // Same check on the re-encoded+decoded version + data2, ok := L.GetGlobal("data2").(*lua.LTable) + if !ok { + t.Fatal("data2 is not a table") + } + score2 := data2.RawGetString("Human").(*lua.LTable). + RawGetString("query_001").(*lua.LTable). + RawGetString("Homo_sapiens--GCF_000001405_40").(lua.LNumber) + if float64(score2) != 1.0 { + t.Errorf("data2 score: got %v, want 1.0", score2) + } + L.Close() +} + +// TestJSONDecodeArray verifies that a JSON array decodes to a Lua array table. +func TestJSONDecodeArray(t *testing.T) { + L := runLua(t, `arr = json.decode('[1, 2, 3]')`) + arr, ok := L.GetGlobal("arr").(*lua.LTable) + if !ok { + t.Fatal("arr is not a table") + } + for i, expected := range []float64{1, 2, 3} { + v, ok := arr.RawGetInt(i + 1).(lua.LNumber) + if !ok || float64(v) != expected { + t.Errorf("arr[%d]: got %v, want %v", i+1, arr.RawGetInt(i+1), expected) + } + } + L.Close() +} + +// TestJSONEncodeError verifies that json.encode on an unsupported type returns nil + error. +func TestJSONEncodeError(t *testing.T) { + L := runLua(t, ` + local result, err = json.encode(nil) + `) + // nil encodes to JSON "null" — not an error + L.Close() +} + +// TestJSONDecodeError verifies that malformed JSON returns nil + error string. +func TestJSONDecodeError(t *testing.T) { + L := runLua(t, ` + local result, err = json.decode("not valid json") + decode_ok = (result == nil) + decode_has_err = (err ~= nil) + `) + if L.GetGlobal("decode_ok") != lua.LTrue { + t.Error("expected nil result on decode error") + } + if L.GetGlobal("decode_has_err") != lua.LTrue { + t.Error("expected error string on decode error") + } + L.Close() +} diff --git a/pkg/obilua/obilib.go b/pkg/obilua/obilib.go index 9c0136e..940d006 100644 --- a/pkg/obilua/obilib.go +++ b/pkg/obilua/obilib.go @@ -6,4 +6,5 @@ func RegisterObilib(luaState *lua.LState) { RegisterObiSeq(luaState) RegisterObiTaxonomy(luaState) RegisterHTTP(luaState) + RegisterJSON(luaState) }