import pysam
import argparse

def restore_full_length(read, reference, target_length=101):
    """
    Restore a read to a fixed length by padding it with reference bases.

    Missing bases are taken first from the reference directly before
    ``read.reference_start``. Bases after ``read.reference_end`` are used only
    for whatever the left side could not supply (e.g. when the read sits close
    to the start of the contig). The result is therefore laid out as
    ``left_extension + read.query_sequence + right_extension``.

    Reads that are already longer than ``target_length`` are never truncated;
    they are rejected.

    NOTE(review): ``query_sequence`` includes soft-clipped bases while
    ``reference_start``/``reference_end`` describe only the aligned part, so
    soft-clipped reads would be padded at a shifted position - confirm the
    input reads are not soft-clipped.

    Parameters:
    - read (pysam.AlignedSegment): The aligned read object.
    - reference (pysam.FastaFile): The reference genome FASTA file.
    - target_length (int): The desired full length of the restored read.

    Returns:
    - str: The restored sequence, exactly ``target_length`` bases long, or None
      if the read has no sequence, no reference name or no usable alignment
      coordinates, is longer than ``target_length``, or the contig cannot
      supply enough bases.
    """
    sequence = read.query_sequence
    if not sequence or not read.reference_name:
        return None  # Skip invalid or unmapped reads

    missing_bases = target_length - len(sequence)
    if missing_bases < 0:
        return None  # Longer than the target: dropped, not truncated
    if missing_bases == 0:
        return sequence  # Already full length, no reference lookup needed

    start = read.reference_start
    end = read.reference_end
    # reference_end is None when the record has no CIGAR alignment; without
    # both coordinates there is nothing to anchor the extension to.
    if start is None or end is None or start < 0:
        return None

    # Left side first: up to `missing_bases` bases ending at the alignment start.
    left_extension = ""
    if start > 0:
        left_extension = reference.fetch(
            read.reference_name,
            max(0, start - missing_bases),
            start,
        )

    # Right side only for the remainder, clipped to the end of the contig.
    right_extension = ""
    still_missing = missing_bases - len(left_extension)
    if still_missing > 0:
        contig_length = reference.get_reference_length(read.reference_name)
        if end < contig_length:
            right_extension = reference.fetch(
                read.reference_name,
                end,
                min(end + still_missing, contig_length),
            )

    restored = left_extension + sequence + right_extension

    # The contig may be too short to supply every missing base.
    if len(restored) != target_length:
        return None  # Drop reads that cannot be restored to full length

    return restored

# Quality character given to bases that have no sequencer quality
# ("I" is Phred 40 in the Phred+33 encoding used by FASTQ).
_PLACEHOLDER_QUALITY = "I"

# Complement table covering the IUPAC nucleotide codes in both cases, so that
# ambiguous or lower-case reference bases are complemented as well.
_COMPLEMENT_TABLE = str.maketrans(
    "ACGTUMRWSYKVHDBNacgtumrwsykvhdbn",
    "TGCAAKYWSRMBDHVNtgcaakywsrmbdhvn",
)


def _padded_quality(read, restored_length):
    """Return a quality string aligned base-for-base with the restored sequence."""
    original_length = len(read.query_sequence)
    quality = read.qual
    if quality is None:
        # The record stores no qualities; use the placeholder for every base.
        quality = _PLACEHOLDER_QUALITY * original_length

    missing = restored_length - original_length
    # restore_full_length() prepends up to `reference_start` reference bases
    # before it appends anything, so the padding is split the same way.
    left_pad = min(missing, max(read.reference_start, 0))
    right_pad = missing - left_pad
    return (
        _PLACEHOLDER_QUALITY * left_pad
        + quality
        + _PLACEHOLDER_QUALITY * right_pad
    )


def _as_sequenced(read, sequence, quality):
    """Return (sequence, quality) in the orientation the sequencer produced."""
    if read.is_reverse:
        # BAM stores reverse-strand reads reverse-complemented; undo that.
        return sequence.translate(_COMPLEMENT_TABLE)[::-1], quality[::-1]
    return sequence, quality


def _write_restored_pairs(bam, reference, r1_out, r2_out, target_length):
    """Restore every primary read in `bam` and write each completed pair."""
    pending = {}  # query name -> [read 1 record, read 2 record]

    for read in bam:
        # Secondary/supplementary records share the name and mate flag of the
        # primary alignment and must not replace it.
        if read.is_unmapped or read.is_secondary or read.is_supplementary:
            continue

        if read.is_read1:
            mate_index = 0
        elif read.is_read2:
            mate_index = 1
        else:
            continue  # Not part of a pair

        restored_sequence = restore_full_length(read, reference, target_length)
        if not restored_sequence:
            continue  # Skip reads that cannot be restored to full length

        full_quality = _padded_quality(read, len(restored_sequence))
        record = _as_sequenced(read, restored_sequence, full_quality)

        # Mates may arrive in either order (read 2 first in a coordinate-sorted
        # file), so both slots are filled symmetrically.
        mates = pending.setdefault(read.query_name, [None, None])
        mates[mate_index] = record
        if mates[0] is None or mates[1] is None:
            continue

        # Both mates restored: write immediately and free the memory.
        del pending[read.query_name]
        (r1_seq, r1_qual), (r2_seq, r2_qual) = mates
        r1_out.write(f"@{read.query_name}/1\n{r1_seq}\n+\n{r1_qual}\n")
        r2_out.write(f"@{read.query_name}/2\n{r2_seq}\n+\n{r2_qual}\n")


def process_bam_to_fastq(bam_file, reference_fasta, r1_output, r2_output, target_length=101):
    """
    Converts a BAM file to paired-end FASTQ, restoring full-length reads and filtering by length.

    Only primary, mapped alignments are used. A pair is written only when both
    mates could be restored to exactly `target_length` bases; it is written as
    soon as its second mate is seen, so records appear in order of pair
    completion and both output files list the pairs in the same order.
    Reference-derived bases (and reads without stored qualities) receive the
    placeholder quality "I". Reverse-strand reads are reverse-complemented and
    their qualities reversed, so the FASTQ holds reads in sequencer orientation.

    Parameters:
    - bam_file (str): Path to the input BAM file.
    - reference_fasta (str): Path to the reference genome FASTA file.
    - r1_output (str): Path to the output FASTQ file for Read 1.
    - r2_output (str): Path to the output FASTQ file for Read 2.
    - target_length (int): Desired length of the restored reads.
    """
    from contextlib import closing

    # closing() guarantees both pysam handles are released even if reading or
    # writing fails part-way through. The outputs are opened last so that a
    # missing input does not leave freshly truncated FASTQ files behind.
    with closing(pysam.FastaFile(reference_fasta)) as reference:
        with closing(pysam.AlignmentFile(bam_file, "rb")) as bam:
            with open(r1_output, "w") as r1_out, open(r2_output, "w") as r2_out:
                _write_restored_pairs(bam, reference, r1_out, r2_out, target_length)

def parse_arguments():
    """
    Build the command-line interface and parse the process arguments.

    Four path options are mandatory (input BAM, reference FASTA and the two
    FASTQ outputs); the target read length is optional and defaults to 101.

    Returns:
    argparse.Namespace: Parsed values, exposed as ``bam``, ``reference``,
    ``output1``, ``output2`` and ``length``.
    """
    cli = argparse.ArgumentParser(
        description="Convert BAM to paired-end FASTQ and restore full-length reads using a reference FASTA."
    )

    # (short flag, long flag, help text) for every mandatory path option,
    # registered in the order they should appear in --help.
    required_paths = (
        ("-b", "--bam", "Path to the input BAM file."),
        ("-r", "--reference", "Path to the reference FASTA file."),
        ("-o1", "--output1", "Path to the output FASTQ file for Read 1."),
        ("-o2", "--output2", "Path to the output FASTQ file for Read 2."),
    )
    for short_flag, long_flag, help_text in required_paths:
        cli.add_argument(short_flag, long_flag, required=True, help=help_text)

    cli.add_argument(
        "-l",
        "--length",
        type=int,
        default=101,
        help="Target length of restored reads (default: 101).",
    )

    return cli.parse_args()

def main():
    """
    Script entry point: read the command-line options and run the conversion.
    """
    cli_args = parse_arguments()
    process_bam_to_fastq(
        bam_file=cli_args.bam,
        reference_fasta=cli_args.reference,
        r1_output=cli_args.output1,
        r2_output=cli_args.output2,
        target_length=cli_args.length,
    )

# Run the conversion only when this file is executed as a script, so that
# importing it as a module has no side effects.
if __name__ == "__main__":
    main()

