diff --git a/pkg/obilua/lua.go b/pkg/obilua/lua.go index a64ea27..34700b3 100644 --- a/pkg/obilua/lua.go +++ b/pkg/obilua/lua.go @@ -2,6 +2,7 @@ package obilua import ( "bytes" + "fmt" "os" "git.metabarcoding.org/obitools/obitools4/obitools4/pkg/obiiter" @@ -16,6 +17,7 @@ func NewInterpreter() *lua.LState { lua := lua.NewState() RegisterObilib(lua) + RegisterObiContext(lua) return lua } @@ -48,16 +50,70 @@ func CompileScript(filePath string) (*lua.FunctionProto, error) { func LuaWorker(proto *lua.FunctionProto) obiseq.SeqWorker { interpreter := NewInterpreter() lfunc := interpreter.NewFunctionFromProto(proto) + interpreter.Push(lfunc) + err := interpreter.PCall(0, lua.MultRet, nil) - f := func(sequence *obiseq.BioSequence) (obiseq.BioSequenceSlice, error) { - interpreter.SetGlobal("sequence", obiseq2Lua(interpreter, sequence)) - interpreter.Push(lfunc) - err := interpreter.PCall(0, lua.MultRet, nil) - - return obiseq.BioSequenceSlice{sequence}, err + if err != nil { + log.Fatalf("Error in executing the lua script") } - return f + result := interpreter.GetGlobal("worker") + + if lua_worker, ok := result.(*lua.LFunction); ok { + f := func(sequence *obiseq.BioSequence) (obiseq.BioSequenceSlice, error) { + // Call the Lua function concat + // Lua analogue: + // str = concat("Go", "Lua") + // print(str) + if err := interpreter.CallByParam(lua.P{ + Fn: lua_worker, // name of Lua function + NRet: 1, // number of returned values + Protect: true, // return err or panic + }, obiseq2Lua(interpreter, sequence)); err != nil { + log.Fatal(err) + } + + lreponse := interpreter.Get(-1) + defer interpreter.Pop(1) + + if reponse, ok := lreponse.(*lua.LUserData); ok { + s := reponse.Value + switch val := s.(type) { + case *obiseq.BioSequence: + return obiseq.BioSequenceSlice{val}, err + default: + return nil, fmt.Errorf("worker function doesn't return the correct type") + } + } + + return nil, fmt.Errorf("worker function doesn't return the correct type") + } + + return f + } + + log.Fatalf("THe worker object is not a function") + return nil + // f := func(sequence *obiseq.BioSequence) (obiseq.BioSequenceSlice, error) { + // interpreter.SetGlobal("sequence", obiseq2Lua(interpreter, sequence)) + // interpreter.Push(lfunc) + // err := interpreter.PCall(0, lua.MultRet, nil) + // result := interpreter.GetGlobal("result") + + // if result != lua.LNil { + // log.Info("youpi ", result) + // } + + // rep := interpreter.GetGlobal("sequence") + + // if rep.Type() == lua.LTUserData { + // ud := rep.(*lua.LUserData) + // sequence = ud.Value.(*obiseq.BioSequence) + // } + + // return obiseq.BioSequenceSlice{sequence}, err + // } + } func LuaProcessor(iterator obiiter.IBioSequence, name, program string, breakOnError bool, nworkers int) obiiter.IBioSequence { @@ -69,10 +125,6 @@ func LuaProcessor(iterator obiiter.IBioSequence, name, program string, breakOnEr newIter.Add(nworkers) - go func() { - newIter.WaitAndClose() - }() - bp := []byte(program) proto, err := Compile(bp, name) @@ -80,6 +132,55 @@ func LuaProcessor(iterator obiiter.IBioSequence, name, program string, breakOnEr log.Fatalf("Cannot compile script %s : %v", name, err) } + interpreter := NewInterpreter() + lfunc := interpreter.NewFunctionFromProto(proto) + interpreter.Push(lfunc) + err = interpreter.PCall(0, lua.MultRet, nil) + + if err != nil { + log.Fatalf("Error in executing the lua script") + } + + result := interpreter.GetGlobal("begin") + if lua_begin, ok := result.(*lua.LFunction); ok { + if err := interpreter.CallByParam(lua.P{ + Fn: lua_begin, // name of Lua function + NRet: 0, // number of returned values + Protect: true, // return err or panic + }); err != nil { + log.Fatal(err) + } + } + + interpreter.Close() + + go func() { + newIter.WaitAndClose() + + interpreter := NewInterpreter() + lfunc := interpreter.NewFunctionFromProto(proto) + interpreter.Push(lfunc) + err = interpreter.PCall(0, lua.MultRet, nil) + + if err != nil { + log.Fatalf("Error in executing the lua script") + } + + result := interpreter.GetGlobal("finish") + if lua_finish, ok := result.(*lua.LFunction); ok { + if err := interpreter.CallByParam(lua.P{ + Fn: lua_finish, // name of Lua function + NRet: 0, // number of returned values + Protect: true, // return err or panic + }); err != nil { + log.Fatal(err) + } + } + + interpreter.Close() + + }() + ff := func(iterator obiiter.IBioSequence) { w := LuaWorker(proto) sw := obiseq.SeqToSliceWorker(w, false) diff --git a/pkg/obilua/lua_obicontext.go b/pkg/obilua/lua_obicontext.go new file mode 100644 index 0000000..d3ee691 --- /dev/null +++ b/pkg/obilua/lua_obicontext.go @@ -0,0 +1,117 @@ +package obilua + +import ( + "sync" + + log "github.com/sirupsen/logrus" + lua "github.com/yuin/gopher-lua" +) + +var __lua_obicontext = &sync.Map{} +var __lua_obicontext_lock = &sync.Mutex{} + +func RegisterObiContext(luaState *lua.LState) { + + table := luaState.NewTable() + luaState.SetField(table, "item", luaState.NewFunction(obicontextGetSet)) + luaState.SetField(table, "lock", luaState.NewFunction(obicontextLock)) + luaState.SetField(table, "unlock", luaState.NewFunction(obicontextUnlock)) + luaState.SetField(table, "trylock", luaState.NewFunction(obicontextTrylock)) + luaState.SetField(table, "inc", luaState.NewFunction(obicontextInc)) + luaState.SetField(table, "dec", luaState.NewFunction(obicontextDec)) + + luaState.SetGlobal("obicontext", table) +} + +func obicontextGetSet(interpreter *lua.LState) int { + key := interpreter.CheckString(1) + + if interpreter.GetTop() == 2 { + value := interpreter.CheckAny(2) + + switch val := value.(type) { + case lua.LBool: + __lua_obicontext.Store(key, bool(val)) + case lua.LNumber: + __lua_obicontext.Store(key, float64(val)) + case lua.LString: + __lua_obicontext.Store(key, string(val)) + case *lua.LTable: + __lua_obicontext.Store(key, Table2Interface(interpreter, val)) + default: + log.Fatalf("Cannot store values of type %s in the obicontext", value.Type().String()) + } + + return 0 + + } + + if value, ok := __lua_obicontext.Load(key); ok { + pushInterfaceToLua(interpreter, value) + } else { + interpreter.Push(lua.LNil) + } + + return 1 +} + +func obicontextInc(interpreter *lua.LState) int { + key := interpreter.CheckString(1) + __lua_obicontext_lock.Lock() + + if value, ok := __lua_obicontext.Load(key); ok { + if v, ok := value.(float64); ok { + v++ + __lua_obicontext.Store(key, v) + __lua_obicontext_lock.Unlock() + interpreter.Push(lua.LNumber(v)) + return 1 + } + } + + __lua_obicontext_lock.Unlock() + log.Fatalf("Cannot increment item %s", key) + + return 0 +} + +func obicontextDec(interpreter *lua.LState) int { + key := interpreter.CheckString(1) + __lua_obicontext_lock.Lock() + defer __lua_obicontext_lock.Unlock() + + if value, ok := __lua_obicontext.Load(key); ok { + if v, ok := value.(float64); ok { + v-- + __lua_obicontext.Store(key, v) + interpreter.Push(lua.LNumber(v)) + return 1 + } + } + + log.Fatalf("Cannot decrement item %s", key) + + return 0 +} + +func obicontextLock(interpreter *lua.LState) int { + + __lua_obicontext_lock.Lock() + + return 0 +} + +func obicontextUnlock(interpreter *lua.LState) int { + + __lua_obicontext_lock.Unlock() + + return 0 +} + +func obicontextTrylock(interpreter *lua.LState) int { + + result := __lua_obicontext_lock.TryLock() + + interpreter.Push(lua.LBool(result)) + return 1 +} diff --git a/pkg/obilua/lua_push_interface.go b/pkg/obilua/lua_push_interface.go index a561860..809325d 100644 --- a/pkg/obilua/lua_push_interface.go +++ b/pkg/obilua/lua_push_interface.go @@ -1,7 +1,7 @@ package obilua import ( - "log" + log "github.com/sirupsen/logrus" lua "github.com/yuin/gopher-lua" ) @@ -32,6 +32,8 @@ func pushInterfaceToLua(L *lua.LState, val interface{}) { pushMapStringBoolToLua(L, v) case map[string]float64: pushMapStringFloat64ToLua(L, v) + case map[string]interface{}: + pushMapStringInterfaceToLua(L, v) case []string: pushSliceStringToLua(L, v) case []int: @@ -43,10 +45,33 @@ func pushInterfaceToLua(L *lua.LState, val interface{}) { case nil: L.Push(lua.LNil) default: - log.Fatalf("Cannot deal with value Mv", val) + log.Fatalf("Cannot deal with value %v", val) } } +func pushMapStringInterfaceToLua(L *lua.LState, m map[string]interface{}) { + // Create a new Lua table + luaTable := L.NewTable() + // Iterate over the Go map and set the key-value pairs in the Lua table + for key, value := range m { + switch v := value.(type) { + case int: + luaTable.RawSetString(key, lua.LNumber(v)) + case float64: + luaTable.RawSetString(key, lua.LNumber(v)) + case bool: + luaTable.RawSetString(key, lua.LBool(v)) + case string: + luaTable.RawSetString(key, lua.LString(v)) + default: + log.Fatalf("Doesn't deal with map containing value %v", v) + } + } + + // Push the Lua table onto the stack + L.Push(luaTable) +} + // pushMapStringIntToLua creates a new Lua table and iterates over the Go map to set key-value pairs in the Lua table. It then pushes the Lua table onto the stack. // // L *lua.LState - the Lua state @@ -57,7 +82,7 @@ func pushMapStringIntToLua(L *lua.LState, m map[string]int) { // Iterate over the Go map and set the key-value pairs in the Lua table for key, value := range m { - L.SetField(luaTable, key, lua.LNumber(value)) + L.SetTable(luaTable, lua.LString(key), lua.LNumber(value)) } // Push the Lua table onto the stack @@ -73,7 +98,7 @@ func pushMapStringStringToLua(L *lua.LState, m map[string]string) { // Iterate over the Go map and set the key-value pairs in the Lua table for key, value := range m { - L.SetField(luaTable, key, lua.LString(value)) + L.SetTable(luaTable, lua.LString(key), lua.LString(value)) } // Push the Lua table onto the stack @@ -94,7 +119,7 @@ func pushMapStringBoolToLua(L *lua.LState, m map[string]bool) { // Iterate over the Go map and set the key-value pairs in the Lua table for key, value := range m { - L.SetField(luaTable, key, lua.LBool(value)) + L.SetTable(luaTable, lua.LString(key), lua.LBool(value)) } // Push the Lua table onto the stack @@ -112,7 +137,7 @@ func pushMapStringFloat64ToLua(L *lua.LState, m map[string]float64) { // Iterate over the Go map and set the key-value pairs in the Lua table for key, value := range m { // Use lua.LNumber since Lua does not differentiate between float and int - L.SetField(luaTable, key, lua.LNumber(value)) + L.SetTable(luaTable, lua.LString(key), lua.LNumber(value)) } // Push the Lua table onto the stack diff --git a/pkg/obilua/lua_table.go b/pkg/obilua/lua_table.go new file mode 100644 index 0000000..f3dd77e --- /dev/null +++ b/pkg/obilua/lua_table.go @@ -0,0 +1,61 @@ +package obilua + +import lua "github.com/yuin/gopher-lua" + +func Table2Interface(interpreter *lua.LState, table *lua.LTable) interface{} { + // 07/03/2024: il y a sans doute plus efficace mais pour l'instant + // ça marche + isArray := true + table.ForEach(func(key, value lua.LValue) { + if _, ok := key.(lua.LNumber); !ok { + isArray = false + } + }) + if isArray { + val := make([]interface{}, table.Len()) + for i := 1; i <= table.Len(); i++ { + val[i-1] = table.RawGetInt(i) + } + return val + } else { + // The table contains a hash + val := make(map[string]interface{}) + table.ForEach(func(k, v lua.LValue) { + if ks, ok := k.(lua.LString); ok { + val[string(ks)] = v + } + }) + return val + } +} + +// } + +// return nil +// } + +// if x := table.RawGetInt(1); x != nil { +// val := make([]interface{}, table.Len()) +// for i := 1; i <= table.Len(); i++ { +// val[i-1] = table.RawGetInt(i) +// } +// return val +// } else { + +// } +// } + +// if lv.Type() == lua.LTTable { +// table := lv.(*lua.LTable) +// isArray := true +// table.ForEach(func(key, value lua.LValue) { +// if _, ok := key.(lua.LNumber); !ok { +// isArray = false +// } +// }) +// if isArray { +// // The table contains an array +// } else { +// // The table contains a hash +// } +// } diff --git a/pkg/obilua/obiseq.go b/pkg/obilua/obiseq.go index b1395cd..c16ae09 100644 --- a/pkg/obilua/obiseq.go +++ b/pkg/obilua/obiseq.go @@ -1,8 +1,6 @@ package obilua import ( - log "github.com/sirupsen/logrus" - "git.metabarcoding.org/obitools/obitools4/obitools4/pkg/obiseq" lua "github.com/yuin/gopher-lua" ) @@ -31,6 +29,7 @@ func obiseq2Lua(interpreter *lua.LState, return ud } + func newObiSeq(luaState *lua.LState) int { seqid := luaState.CheckString(1) seq := []byte(luaState.CheckString(2)) @@ -47,12 +46,17 @@ func newObiSeq(luaState *lua.LState) int { } var bioSequenceMethods = map[string]lua.LGFunction{ - "id": bioSequenceGetSetId, - "sequence": bioSequenceGetSetSequence, - "definition": bioSequenceGetSetDefinition, - "count": bioSequenceGetSetCount, - "taxid": bioSequenceGetSetTaxid, - "attribute": bioSequenceGetSetAttribute, + "id": bioSequenceGetSetId, + "sequence": bioSequenceGetSetSequence, + "definition": bioSequenceGetSetDefinition, + "count": bioSequenceGetSetCount, + "taxid": bioSequenceGetSetTaxid, + "attribute": bioSequenceGetSetAttribute, + "len": bioSequenceGetLength, + "has_sequence": bioSequenceHasSequence, + "has_qualities": bioSequenceHasQualities, + "source": bioSequenceGetSource, + "md5": bioSequenceGetMD5, } // checkBioSequence checks if the first argument in the Lua stack is a *obiseq.BioSequence. @@ -132,13 +136,17 @@ func bioSequenceGetSetAttribute(luaState *lua.LState) int { if luaState.GetTop() == 3 { ud := luaState.CheckAny(3) - log.Infof("ud : %v [%v]", ud, ud.Type()) // // Perhaps the code needs some type checking on ud.Value // It's a first test // - s.SetAttribute(attName, ud) + if v, ok := ud.(*lua.LTable); ok { + s.SetAttribute(attName, Table2Interface(luaState, v)) + } else { + s.SetAttribute(attName, ud) + } + return 0 } @@ -152,3 +160,42 @@ func bioSequenceGetSetAttribute(luaState *lua.LState) int { return 1 } + +func bioSequenceGetLength(luaState *lua.LState) int { + s := checkBioSequence(luaState) + luaState.Push(lua.LNumber(s.Len())) + return 1 +} + +func bioSequenceHasSequence(luaState *lua.LState) int { + s := checkBioSequence(luaState) + luaState.Push(lua.LBool(s.HasSequence())) + return 1 +} + +func bioSequenceHasQualities(luaState *lua.LState) int { + s := checkBioSequence(luaState) + luaState.Push(lua.LBool(s.HasQualities())) + return 1 +} + +func bioSequenceGetSource(luaState *lua.LState) int { + s := checkBioSequence(luaState) + if s.HasSource() { + luaState.Push(lua.LString(s.Source())) + } else { + luaState.Push(lua.LNil) + } + return 1 +} + +func bioSequenceGetMD5(luaState *lua.LState) int { + s := checkBioSequence(luaState) + md5 := s.MD5() + rt := luaState.NewTable() + for i := 0; i < 16; i++ { + rt.Append(lua.LNumber(md5[i])) + } + luaState.Push(rt) + return 1 +} diff --git a/pkg/obitools/obiscript/options.go b/pkg/obitools/obiscript/options.go index 5f64ae3..338b59b 100644 --- a/pkg/obitools/obiscript/options.go +++ b/pkg/obitools/obiscript/options.go @@ -47,37 +47,22 @@ func CLIAskScriptTemplate() bool { } func CLIScriptTemplate() string { - return ` - import { - "sync" - "git.metabarcoding.org/obitools/obitools4/obitools4/pkg/obiseq" - } - // - // Begin function run before the first sequence being processed - // + return `function begin() + obicontext.item("compteur",0) +end - func Begin(environment *sync.Map) { +function worker(sequence) + samples = sequence:attribute("merged_sample") + samples["tutu"]=4 + sequence:attribute("merged_sample",samples) + sequence:attribute("toto",44444) + nb = obicontext.inc("compteur") + sequence:id("seq_" .. nb) + return sequence +end - } - - // - // Begin function run after the last sequence being processed - // - - func End(environment *sync.Map) { - - } - - // - // Worker function run for each sequence validating the selection predicat as specified by - // the command line options. - // - // The function must return the altered sequence. - // If the function returns nil, the sequence is discarded from the output - func Worker(sequence *obiseq.BioSequence, environment *sync.Map) *obiseq.BioSequence { - - - return sequence - } +function finish() + print("compteur = " .. obicontext.item("compteur")) +end ` }