Cython API: when importing a file in a DMS, its length is computed

beforehand for the progress bar
This commit is contained in:
Celine Mercier
2019-03-13 18:35:32 +01:00
parent 50e7cd61a6
commit d88390c6d8
5 changed files with 97 additions and 48 deletions

View File

@ -87,14 +87,19 @@ def run(config):
DMS.obi_atexit()
logger("info", "obi import : imports file into a DMS")
logger("info", "obi import: imports an object (file(s), obiview, taxonomy...) into a DMS")
entry_count = -1
if not config['obi']['taxdump']:
input = open_uri(config['obi']['inputURI'])
if input is None: # TODO check for bytes instead now?
raise Exception("Could not open input URI")
# TODO uuuuh
entry_count = input[4]
logger("info", "Importing %d entries", entry_count)
# TODO a bit dirty
if input[2]==Nuc_Seq:
v = View_NUC_SEQS
else:
@ -117,7 +122,7 @@ def run(config):
output[0].close()
return
pb = ProgressBar(10000000, config, seconde=5) # TODO should be number of records in file
pb = ProgressBar(entry_count, config, seconde=5)
entries = input[1]
@ -250,13 +255,16 @@ def run(config):
i+=1
pb(i, force=True)
print("", file=sys.stderr)
# Save command config in View and DMS comments
command_line = " ".join(sys.argv[1:])
view.write_config(config, "import", command_line, input_str=[os.path.abspath(config['obi']['inputURI'])])
output[0].record_command_line(command_line)
print("\n")
print(view.__repr__())
#print("\n\nOutput view:\n````````````", file=sys.stderr)
#print(repr(view), file=sys.stderr)
try:
input[0].close()
@ -267,3 +275,4 @@ def run(config):
except AttributeError:
pass
logger("info", "Done.")

View File

@ -23,20 +23,20 @@ def is_ngsfilter_line(line): # TODO doesn't work?
return False
def entryIteratorFactory(lineiterator,
int skip=0,
only=None,
bytes seqtype=b'nuc',
int offset=-1,
bint noquality=False,
bint skiperror=True,
bint header=False,
bytes sep=None,
bytes dec=b'.',
bytes nastring=b"NA",
bint stripwhite=True,
bint blanklineskip=True,
bytes commentchar=b"#",
int buffersize=100000000):
int skip=0,
only=None,
bytes seqtype=b'nuc',
int offset=-1,
bint noquality=False,
bint skiperror=True,
bint header=False,
bytes sep=None,
bytes dec=b'.',
bytes nastring=b"NA",
bint stripwhite=True,
bint blanklineskip=True,
bytes commentchar=b"#",
int buffersize=100000000):
if isinstance(lineiterator, (str, bytes)):
lineiterator=uopen(lineiterator)
@ -65,7 +65,7 @@ def entryIteratorFactory(lineiterator,
format=b"embl"
elif first[0:6]==b'LOCUS ':
format=b"genbank"
elif first[0:11]==b'#@ecopcr-v2': # TODO v2????
elif first[0:8]==b'#@ecopcr':
format=b"ecopcrfile"
elif is_ngsfilter_line(first):
format=b"ngsfilter"
@ -83,7 +83,8 @@ def entryIteratorFactory(lineiterator,
firstline=first,
buffersize=buffersize,
nastring=nastring),
Nuc_Seq)
Nuc_Seq,
format)
else:
raise NotImplementedError()
elif format==b'fastq':
@ -94,7 +95,8 @@ def entryIteratorFactory(lineiterator,
firstline=first,
buffersize=buffersize,
nastring=nastring),
Nuc_Seq)
Nuc_Seq,
format)
elif format==b'tabular':
return (tabIterator(lineiterator,
header = header,
@ -108,7 +110,8 @@ def entryIteratorFactory(lineiterator,
only = only,
firstline=first,
buffersize=buffersize),
dict)
dict,
format)
elif format==b'ngsfilter':
return (ngsfilterIterator(lineiterator,
sep = sep,
@ -121,7 +124,8 @@ def entryIteratorFactory(lineiterator,
only = only,
firstline=first,
buffersize=buffersize),
dict)
dict,
format)
elif format==b'embl':
return (emblIterator(lineiterator,
@ -129,7 +133,8 @@ def entryIteratorFactory(lineiterator,
only=only,
firstline=first,
buffersize=buffersize),
dict)
dict,
format)
raise NotImplementedError('File format not yet implemented')
raise NotImplementedError('File format iterator not implemented yet')

View File

@ -20,7 +20,7 @@ from obitools3.format.fastq import FastqFormat
from obitools3.dms.obiseq import Nuc_Seq
from obitools3.apps.config import getConfiguration,logger
from obitools3.apps.temp import get_temp_dms
from obitools3.utils cimport tobytes # TODO because can't read options as bytes
from obitools3.utils cimport tobytes, count_entries # TODO tobytes because can't read options as bytes
from obitools3.dms.capi.obierrno cimport obi_errno, \
OBIVIEW_ALREADY_EXISTS_ERROR
@ -159,6 +159,7 @@ Reads an URI and returns a tuple containing:
(2) The opened view or iterator on the opened file or writer
(3) The class of object returned or handled by (2)
(4) The original URI in bytes
(5) The number of entries (if input URI) or -1 if unavailable
'''
def open_uri(uri,
bint input=True,
@ -209,7 +210,8 @@ def open_uri(uri,
return (dms[0],
dms[1],
type(dms[1]),
urlunparse(urip))
urlunparse(urip),
len(dms[0]))
try:
resource=open_dms_element(dms[0],
dms[1],
@ -230,7 +232,8 @@ def open_uri(uri,
return (resource[0],
resource[1],
type(resource[1]),
urlunparse(urip))
urlunparse(urip),
len(resource[1]))
except Exception as e:
global obi_errno
if obi_errno == OBIVIEW_ALREADY_EXISTS_ERROR:
@ -503,19 +506,19 @@ def open_uri(uri,
raise NotImplementedError('Output sequence file format not implemented')
else:
if input:
iseq, objclass = entryIteratorFactory(file,
skip, only,
seqtype,
offset,
noquality,
skiperror,
header,
sep,
dec,
nastring,
stripwhite,
blanklineskip,
commentchar)
iseq, objclass, format = entryIteratorFactory(file,
skip, only,
seqtype,
offset,
noquality,
skiperror,
header,
sep,
dec,
nastring,
stripwhite,
blanklineskip,
commentchar)
else: # default export is in fasta? or tab? TODO
objclass = Nuc_Seq # Nuc_Seq_Stored? TODO
iseq = FastaNucWriter(FastaFormat(printNAKeys=printna, NAString=nastring),
@ -525,6 +528,8 @@ def open_uri(uri,
#tmpdms = get_temp_dms()
return (file, iseq, objclass, urib)
entry_count = -1
if input:
entry_count = count_entries(file, format)
return (file, iseq, objclass, urib, entry_count)

View File

@ -2,6 +2,8 @@
from obitools3.dms.capi.obitypes cimport obitype_t, index_t
cpdef bytes format_separator(bytes format)
cpdef int count_entries(file, bytes format)
cdef obi_errno_to_exception(int obi_errno, index_t line_nb=*, object elt_id=*, str error_message=*)

View File

@ -16,6 +16,34 @@ from obitools3.dms.capi.obierrno cimport OBI_LINE_IDX_ERROR, \
#obi_errno
import re
import mmap
cpdef bytes format_separator(bytes format):
    """
    Return the pattern that separates two consecutive entries in a file of
    the given format.

    The returned bytes object is compiled as a regular expression by
    count_entries(), so it may contain regex syntax (e.g. a character class).

    @param format: the format name as bytes (b"fasta", b"fastq",
                   b"ngsfilter", b"tabular", b"genbank", b"embl",
                   b"ecopcr" or b"ecopcrfile")
    @return: the separator pattern as bytes, or None if the format is not
             handled
    """
    if format == b"fasta":
        return b"\n>"
    elif format == b"fastq":
        return b"\n@"
    elif format == b"ngsfilter" or format == b"tabular":
        # One entry per line
        return b"\n"
    elif format == b"genbank" or format == b"embl":
        return b"\n//"
    elif format == b"ecopcr" or format == b"ecopcrfile":
        # The format detection labels ecoPCR files b"ecopcrfile"; b"ecopcr"
        # is kept for backward compatibility.
        # Matches any line start that is not a '#' comment line.
        return b"\n[^#]"
    else:
        return None
cpdef int count_entries(file, bytes format):
    """
    Estimate the number of entries in an opened file by counting the
    occurrences of the entry separator of its format.

    The file is memory-mapped read-only and scanned with the regular
    expression returned by format_separator(). This is a best-effort
    computation: any failure yields -1 instead of raising.

    @param file: an opened file object; must provide a working fileno()
    @param format: the format name as bytes, as accepted by
                   format_separator()
    @return: the number of separator matches found in the file, or -1 if
             the format has no known separator or the file could not be
             mapped/scanned
    """
    try:
        sep = format_separator(format)
        if sep is None:
            return -1
        pattern = re.compile(sep)
        mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
        try:
            # Count matches lazily rather than materialising the full list
            # of matches, which can be very large for big files.
            count = 0
            for _match in pattern.finditer(mmapped_file):
                count += 1
            return count
        finally:
            # Always release the mapping, even if the scan fails
            mmapped_file.close()
    except Exception:
        # Best effort: no fileno(), empty or unmappable file, etc.
        # Does not catch KeyboardInterrupt/SystemExit.
        return -1
# TODO RollbackException?