DNA-Diffusion / app.py
openfree's picture
Update app.py
3499851 verified
raw
history blame
26.9 kB
"""
Enhanced DNA-Diffusion Gradio Application
With scientific tools, analysis features, and LLM chat integration
"""
import asyncio
import html
import json
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
import gradio as gr
import numpy as np
import requests
# Set up root logging once, then grab this module's own logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Try to import spaces for GPU decoration
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

    class spaces:
        """No-op stand-in used when the ``spaces`` package cannot be imported."""

        @staticmethod
        def GPU(func=None, duration=60):
            """Return the decorated function unchanged.

            Works both as a decorator factory (``@spaces.GPU(duration=60)``)
            and as a bare decorator (``@spaces.GPU``), so code keeps running
            whichever form it uses.

            Args:
                func: The function being decorated when used bare; otherwise
                    ``None`` (or a positional duration, which is ignored).
                duration: Accepted for signature compatibility; ignored.
            """
            def decorator(inner):
                return inner

            # Bare usage passes the function itself as the first argument.
            if callable(func):
                return func
            return decorator
# Optional dependency: the DNA-Diffusion model wrapper. MODEL_AVAILABLE
# records whether the import succeeded so callers can fall back gracefully.
try:
    from dna_diffusion_model import DNADiffusionModel, get_model
except ImportError as import_error:
    logger.warning(f"DNA-Diffusion model not available: {import_error}")
    MODEL_AVAILABLE = False
else:
    MODEL_AVAILABLE = True
# File name of the enhanced HTML interface.
# NOTE(review): nothing in the visible code opens or reads this file; the
# constant is only defined here — confirm whether it is still needed.
HTML_FILE = "enhanced-dna-interface.html"
# Codon table for translation (DNA codon -> one-letter amino acid, '*' = stop).
# Codons are enumerated in T, C, A, G order for each of the three positions;
# every four-letter group below covers one (first base, second base) pair.
_CODON_BASES = 'TCAG'
_CODON_AMINO_ACIDS = (
    'FFLL' 'SSSS' 'YY**' 'CC*W'
    'LLLL' 'PPPP' 'HHQQ' 'RRRR'
    'IIIM' 'TTTT' 'NNKK' 'SSRR'
    'VVVV' 'AAAA' 'DDEE' 'GGGG'
)
CODON_TABLE = dict(zip(
    (
        first + second + third
        for first in _CODON_BASES
        for second in _CODON_BASES
        for third in _CODON_BASES
    ),
    _CODON_AMINO_ACIDS,
))
# Recognition sequences (5'->3') of the restriction enzymes that are searched
# for, keyed by enzyme name.
RESTRICTION_ENZYMES = dict(
    EcoRI='GAATTC',
    BamHI='GGATCC',
    HindIII='AAGCTT',
    PstI='CTGCAG',
    SalI='GTCGAC',
    XbaI='TCTAGA',
    NotI='GCGGCCGC',
    XhoI='CTCGAG',
    NdeI='CATATG',
    NcoI='CCATGG',
)
@dataclass
class AnalysisResult:
    """Data class bundling the values produced by one sequence analysis."""
    # DNA sequence that was analysed.
    sequence: str
    # GC content as a percentage (0-100).
    gc_content: float
    # Estimated melting temperature.
    melting_temp: float
    # Enzyme name -> 0-based start indices of its recognition site.
    restriction_sites: Dict[str, List[int]]
    # (start, end, frame label) for each open reading frame found.
    orfs: List[Tuple[int, int, str]]
    # Primer-pair description, or None when primer design found no pair.
    primers: Optional[Dict[str, Any]]
    # Human-readable summary text of the translated protein.
    protein_analysis: str
class ScientificAnalyzer:
    """Stateless helpers for basic DNA sequence analysis.

    Every method is static. Bases are compared case-sensitively against the
    upper-case letters ``A``, ``C``, ``G`` and ``T``.
    """

    @staticmethod
    def calculate_gc_content(sequence: str) -> float:
        """Return the percentage (0-100) of G and C bases; 0 for an empty sequence."""
        if not sequence:
            return 0.0
        gc_count = sequence.count('G') + sequence.count('C')
        return (gc_count / len(sequence)) * 100

    @staticmethod
    def calculate_melting_temp(sequence: str) -> float:
        """Estimate the melting temperature with simple empirical formulas.

        Sequences shorter than 14 nt use the Wallace rule,
        ``4 * (G + C) + 2 * (A + T)``. Longer sequences use
        ``81.5 + 0.41 * (%GC) - 675 / N``. Neither branch models
        nearest-neighbor thermodynamics or salt concentration.
        """
        if len(sequence) < 14:
            strong = sequence.count('G') + sequence.count('C')
            weak = sequence.count('A') + sequence.count('T')
            return float(4 * strong + 2 * weak)
        gc_content = ScientificAnalyzer.calculate_gc_content(sequence)
        return 81.5 + 0.41 * gc_content - 675 / len(sequence)

    @staticmethod
    def find_restriction_sites(sequence: str) -> Dict[str, List[int]]:
        """Locate every occurrence of the known recognition sequences.

        Returns:
            Mapping of enzyme name to the 0-based start indices of its
            recognition sequence (overlapping hits included). Enzymes with
            no hit are omitted.
        """
        sites = {}
        for enzyme, pattern in RESTRICTION_ENZYMES.items():
            positions = []
            hit = sequence.find(pattern)
            while hit != -1:
                positions.append(hit)
                # Resume one base further on so overlapping sites are kept.
                hit = sequence.find(pattern, hit + 1)
            if positions:
                sites[enzyme] = positions
        return sites

    @staticmethod
    def find_orfs(sequence: str, min_length: int = 100) -> List[Tuple[int, int, str]]:
        """Find open reading frames in the three forward frames.

        An ORF runs from an ``ATG`` to the first in-frame stop codon. Only
        the given strand is scanned; the reverse complement is not searched.

        Args:
            sequence: DNA sequence to scan.
            min_length: Minimum number of nucleotides between the start
                codon and the stop codon (stop codon excluded).

        Returns:
            ``(start, end, label)`` tuples where ``start`` is the index of
            the ``ATG``, ``end`` is the index just past the stop codon and
            ``label`` is ``"Frame +1"``, ``"Frame +2"`` or ``"Frame +3"``.
        """
        orfs = []
        start_codon = 'ATG'
        stop_codons = ('TAA', 'TAG', 'TGA')
        last_codon_start = len(sequence) - 2
        for frame in range(3):
            i = frame
            while i < last_codon_start:
                if sequence[i:i + 3] == start_codon:
                    # Found a start codon; walk in-frame until a stop codon.
                    for j in range(i + 3, last_codon_start, 3):
                        if sequence[j:j + 3] in stop_codons:
                            if j - i >= min_length:
                                orfs.append((i, j + 3, f"Frame +{frame + 1}"))
                            # Continue scanning after this stop codon.
                            i = j
                            break
                i += 3
        return orfs

    @staticmethod
    def design_primers(sequence: str, product_size: int = 500,
                       primer_length: int = 20) -> Optional[Dict[str, Any]]:
        """Pick the first primer pair whose melting temperatures are within 5 degrees.

        Candidate amplicons of ``product_size`` nt are tried at start
        offsets 0, 100, 200, ... The forward primer is the first
        ``primer_length`` bases of the amplicon; the reverse primer is the
        reverse complement of its last ``primer_length`` bases.

        Args:
            sequence: Template DNA sequence.
            product_size: Desired amplicon length in nucleotides.
            primer_length: Length of each primer in nucleotides.

        Returns:
            A dict with keys ``forward``, ``reverse``, ``forward_tm``,
            ``reverse_tm``, ``product_size`` and ``position``, or ``None``
            when the template is too short or no pair qualifies.
        """
        if primer_length <= 0 or product_size < primer_length:
            return None

        # "+ 1" so an amplicon ending exactly at the end of the template
        # (including a template of exactly product_size) is considered.
        last_start = len(sequence) - product_size
        for start in range(0, last_start + 1, 100):
            amplicon_end = start + product_size
            forward = sequence[start:start + primer_length]
            reverse = ScientificAnalyzer.reverse_complement(
                sequence[amplicon_end - primer_length:amplicon_end]
            )
            forward_tm = ScientificAnalyzer.calculate_melting_temp(forward)
            reverse_tm = ScientificAnalyzer.calculate_melting_temp(reverse)
            if abs(forward_tm - reverse_tm) < 5:  # Similar Tm
                return {
                    'forward': forward,
                    'reverse': reverse,
                    'forward_tm': forward_tm,
                    'reverse_tm': reverse_tm,
                    'product_size': product_size,
                    'position': start,
                }
        return None

    @staticmethod
    def reverse_complement(sequence: str) -> str:
        """Return the reverse complement; characters other than A/C/G/T are kept as-is."""
        return sequence[::-1].translate(str.maketrans('ATCG', 'TAGC'))

    @staticmethod
    def codon_optimize(protein_sequence: str, organism: str = "E.coli") -> str:
        """Back-translate a protein using one fixed codon per amino acid.

        Simplified: a single built-in E. coli preference table is used and
        ``organism`` is currently ignored. Residues missing from the table
        (for example ``X`` or ``*``) are skipped.
        """
        ecoli_preferred_codons = {
            'F': 'TTT', 'L': 'CTG', 'S': 'TCT', 'Y': 'TAT',
            'C': 'TGC', 'W': 'TGG', 'P': 'CCG', 'H': 'CAT',
            'Q': 'CAG', 'R': 'CGT', 'I': 'ATT', 'M': 'ATG',
            'T': 'ACC', 'N': 'AAC', 'K': 'AAA', 'V': 'GTT',
            'A': 'GCT', 'D': 'GAT', 'E': 'GAA', 'G': 'GGT',
        }
        return ''.join(
            ecoli_preferred_codons[aa]
            for aa in protein_sequence
            if aa in ecoli_preferred_codons
        )
class ProteinStructurePredictor:
    """Placeholder protein structure heuristics.

    Nothing here calls an external service: confidence is a random number
    and the secondary-structure and domain calls are simple residue rules.
    """

    @staticmethod
    async def predict_structure(protein_sequence: str) -> Dict[str, Any]:
        """Return a mock prediction for ``protein_sequence``.

        Returns:
            Dict with ``confidence`` (random value drawn from 70-95),
            ``secondary_structure`` (one letter per residue), ``domains``
            (list of domain dicts) and ``pdb_data`` (always ``None``).
        """
        return {
            'confidence': np.random.uniform(70, 95),
            'secondary_structure': ProteinStructurePredictor._predict_secondary_structure(protein_sequence),
            'domains': ProteinStructurePredictor._predict_domains(protein_sequence),
            'pdb_data': None,  # Would contain actual 3D coordinates
        }

    @staticmethod
    def _predict_secondary_structure(sequence: str) -> str:
        """Map each residue to ``B`` (V/I/L/M/F/Y/W), ``L`` (D/E/K/R) or ``H`` (other)."""
        def classify(aa: str) -> str:
            if aa in 'VILMFYW':  # Hydrophobic - treated as beta sheet
                return 'B'
            if aa in 'DEKR':  # Charged - treated as loop
                return 'L'
            return 'H'  # Everything else - treated as helix

        return ''.join(classify(aa) for aa in sequence)

    @staticmethod
    def _predict_domains(sequence: str) -> List[Dict[str, Any]]:
        """Flag a mock zinc-finger domain for cysteine-rich sequences.

        A domain is reported when the sequence contains a C-x-x-C motif
        (two cysteines separated by any two residues) or when more than
        10% of its residues are cysteine.

        Returns:
            A list holding at most one domain dict with ``name``, ``start``,
            ``end`` and ``confidence``. ``end`` never exceeds the sequence
            length.
        """
        # "x" in C-x-x-C means "any residue", so compare positions i and
        # i + 3 instead of searching for the literal text "CXXC".
        has_cxxc_motif = any(
            sequence[i] == 'C' and sequence[i + 3] == 'C'
            for i in range(len(sequence) - 3)
        )
        cysteine_rich = sequence.count('C') > len(sequence) * 0.1

        domains = []
        if has_cxxc_motif or cysteine_rich:
            domains.append({
                'name': 'Zinc finger domain',
                'start': 0,
                'end': min(30, len(sequence)),
                'confidence': 85,
            })
        return domains
class LLMChatAssistant:
    """Chat assistant backed by a hosted chat-completions endpoint.

    The bearer token is read from the ``FRIENDLI_TOKEN`` environment
    variable. ``conversation_history`` records completed exchanges; it is
    kept for reference only and is not sent back to the API.
    """

    def __init__(self):
        self.api_token = os.getenv("FRIENDLI_TOKEN")
        self.conversation_history = []

    async def chat(self, message: str, context: Dict[str, Any], language: str = "en") -> str:
        """Answer ``message`` using ``context`` about the current sequence.

        Args:
            message: The user's question.
            context: Analysis values; the keys ``sequence``, ``cell_type``,
                ``gc_content`` and ``restriction_sites`` are read.
            language: ``"ko"`` selects the Korean system prompt; anything
                else selects the English one.

        Returns:
            The assistant's reply, or a human-readable error string. This
            method does not raise: failures are logged and reported as text.
        """
        if not self.api_token:
            return "Chat unavailable: API token not configured"
        try:
            system_prompt = self._build_system_prompt(language)
            user_prompt = self._build_user_prompt(message, context, language)
            response = await self._call_llm_api(system_prompt, user_prompt)
        except Exception as e:
            logger.error(f"Chat error: {e}")
            return f"Chat error: {str(e)}"

        # Record the exchange only once it has succeeded, so the history
        # never holds a user turn without its answer.
        self.conversation_history.append({"role": "user", "content": message})
        self.conversation_history.append({"role": "assistant", "content": response})
        return response

    def _build_system_prompt(self, language: str) -> str:
        """Return the system prompt in Korean (``"ko"``) or English (default)."""
        if language == "ko":
            return (
                "당신은 분자생물학 전문가 AI 어시스턴트입니다.\n"
                "DNA 시퀀스 분석, 단백질 구조 예측, 실험 설계, 프라이머 디자인 등을 도와드립니다.\n"
                "과학적으로 정확하면서도 이해하기 쉽게 설명해드립니다."
            )
        return (
            "You are an expert molecular biology AI assistant.\n"
            "You help with DNA sequence analysis, protein structure prediction, "
            "experiment design, primer design, and more.\n"
            "Provide scientifically accurate yet easy to understand explanations."
        )

    def _build_user_prompt(self, message: str, context: Dict[str, Any], language: str) -> str:
        """Prefix the user's question with a short summary of ``context``.

        Only the first 50 characters of the sequence are included.
        ``language`` is accepted for symmetry with the system prompt but is
        not used.
        """
        sequence = context.get('sequence', 'None')
        if sequence is None:
            # An explicit None cannot be sliced; show it like a missing key.
            sequence = 'None'
        context_info = "\n".join([
            "",
            f"Current sequence: {sequence[:50]}...",
            f"Cell type: {context.get('cell_type', 'Unknown')}",
            f"GC content: {context.get('gc_content', 'N/A')}%",
            f"Restriction sites found: {len(context.get('restriction_sites', {}))}",
            "",
        ])
        return f"{context_info}\n\nUser question: {message}"

    async def _call_llm_api(self, system_prompt: str, user_prompt: str) -> str:
        """POST the two prompts to the chat-completions endpoint.

        Returns:
            The text content of the first choice in the response.

        Raises:
            aiohttp.ClientResponseError: If the server answers with an HTTP
                error status.
            KeyError, IndexError: If the response body lacks the expected
                ``choices[0].message.content`` structure.
        """
        url = "https://api.friendli.ai/dedicated/v1/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": "dep89a2fld32mcm",
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            "max_tokens": 500,
            "temperature": 0.7
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                # Surface HTTP failures (401, 429, 5xx, ...) as a clear
                # error instead of a KeyError on the error payload.
                response.raise_for_status()
                result = await response.json()
        return result['choices'][0]['message']['content']
class EnhancedDNAApp:
    """Application object: model loading, generation, analysis, chat and export.

    ``current_analysis`` caches the values of the most recent successful
    generation; the chat and export methods read it.
    """

    # Molecular weights (Da) of the free amino acids, keyed by one-letter code.
    _AA_WEIGHTS = {
        'A': 89.1, 'R': 174.2, 'N': 132.1, 'D': 133.1, 'C': 121.2,
        'E': 147.1, 'Q': 146.2, 'G': 75.1, 'H': 155.2, 'I': 131.2,
        'L': 131.2, 'K': 146.2, 'M': 149.2, 'F': 165.2, 'P': 115.1,
        'S': 105.1, 'T': 119.1, 'W': 204.2, 'Y': 181.2, 'V': 117.1
    }
    # Mass (Da) of the water molecule released when a peptide bond forms.
    _WATER_MASS = 18.015

    def __init__(self):
        self.model = None
        self.model_loading = False
        self.model_error = None
        self.analyzer = ScientificAnalyzer()
        self.structure_predictor = ProteinStructurePredictor()
        self.chat_assistant = LLMChatAssistant()
        self.current_analysis = None

    def initialize_model(self):
        """Load the DNA-Diffusion model, recording any failure in ``model_error``.

        Does nothing when the model module is unavailable (only sets
        ``model_error``) or when a load is already in progress.
        """
        if not MODEL_AVAILABLE:
            self.model_error = "DNA-Diffusion model module not available"
            return
        if self.model_loading:
            return
        self.model_loading = True
        try:
            logger.info("Starting model initialization...")
            self.model = get_model()
            logger.info("Model initialized successfully!")
            self.model_error = None
        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            self.model_error = str(e)
            self.model = None
        finally:
            self.model_loading = False

    @spaces.GPU(duration=60)
    def generate_and_analyze(self, cell_type: str, guidance_scale: float = 1.0, language: str = "en"):
        """Generate a sequence, analyse it and return the result as a JSON string.

        When no model is loaded a random 200-nt sequence is used instead.
        ``language`` is accepted but not used here.

        Returns:
            JSON text with ``sequence`` and ``analysis`` keys, or
            ``{"error": ...}`` when anything fails.
        """
        try:
            if MODEL_AVAILABLE and self.model:
                result = self.model.generate(cell_type, guidance_scale)
                sequence = result['sequence']
            else:
                # Mock generation
                import random
                sequence = ''.join(random.choice(['A', 'T', 'C', 'G']) for _ in range(200))

            analysis = self.analyze_sequence(sequence, cell_type)

            # Store current analysis for chat context
            self.current_analysis = {
                'sequence': sequence,
                'cell_type': cell_type,
                'gc_content': analysis.gc_content,
                'restriction_sites': analysis.restriction_sites,
                'orfs': analysis.orfs,
                'primers': analysis.primers
            }
            return json.dumps({
                'sequence': sequence,
                'analysis': {
                    'gc_content': analysis.gc_content,
                    'melting_temp': analysis.melting_temp,
                    'restriction_sites': analysis.restriction_sites,
                    'orfs': analysis.orfs,
                    'primers': analysis.primers,
                    'protein_analysis': analysis.protein_analysis
                }
            })
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return json.dumps({"error": str(e)})

    def analyze_sequence(self, sequence: str, cell_type: str) -> AnalysisResult:
        """Run every analysis on ``sequence`` and bundle the results.

        ``cell_type`` is accepted but not used by the analysis itself.
        """
        protein_seq = self.translate_to_protein(sequence)
        return AnalysisResult(
            sequence=sequence,
            gc_content=self.analyzer.calculate_gc_content(sequence),
            melting_temp=self.analyzer.calculate_melting_temp(sequence),
            restriction_sites=self.analyzer.find_restriction_sites(sequence),
            orfs=self.analyzer.find_orfs(sequence),
            primers=self.analyzer.design_primers(sequence),
            protein_analysis=self.analyze_protein_basic(protein_seq)
        )

    def translate_to_protein(self, dna_sequence: str) -> str:
        """Translate frame 1 of ``dna_sequence`` up to the first stop codon.

        Codons missing from ``CODON_TABLE`` become ``X``; a trailing
        partial codon is ignored.
        """
        protein = []
        for i in range(0, len(dna_sequence) - 2, 3):
            aa = CODON_TABLE.get(dna_sequence[i:i + 3], 'X')
            if aa == '*':
                break
            protein.append(aa)
        return ''.join(protein)

    def analyze_protein_basic(self, protein_sequence: str) -> str:
        """Summarise length, approximate mass and residue composition as text.

        The molecular weight is the sum of the free amino-acid masses minus
        one water per peptide bond (``length - 1`` bonds).
        """
        if not protein_sequence:
            return "No protein sequence generated"

        length = len(protein_sequence)
        residue_mass = sum(self.get_aa_weight(aa) for aa in protein_sequence)
        # Each peptide bond releases one water molecule, so the chain is
        # lighter than the sum of its free amino acids.
        molecular_weight = residue_mass - self._WATER_MASS * (length - 1)

        hydrophobic = sum(1 for aa in protein_sequence if aa in 'AILMFVPW')
        charged = sum(1 for aa in protein_sequence if aa in 'DEKR')
        return "\n".join([
            "",
            f"Protein length: {length} amino acids",
            f"Molecular weight: ~{molecular_weight:.1f} Da",
            f"Hydrophobic residues: {hydrophobic} ({hydrophobic/length*100:.1f}%)",
            f"Charged residues: {charged} ({charged/length*100:.1f}%)",
            "",
        ])

    def get_aa_weight(self, aa: str) -> float:
        """Return the free amino-acid mass for ``aa``; 100 for unknown codes."""
        return self._AA_WEIGHTS.get(aa, 100)

    async def handle_chat(self, message: str, language: str = "en") -> str:
        """Forward ``message`` to the chat assistant with the current analysis as context."""
        if not self.current_analysis:
            return "Please generate a sequence first to get context-aware assistance."
        return await self.chat_assistant.chat(message, self.current_analysis, language)

    def export_results(self, format_type: str) -> str:
        """Serialise the current analysis.

        Args:
            format_type: ``"genbank"``, ``"fasta"`` or ``"json"``
                (case-insensitive).

        Returns:
            The exported text, or an explanatory message when nothing has
            been generated yet or the format is unknown.
        """
        if not self.current_analysis:
            return "No analysis to export"
        exporters = {
            "genbank": self._export_genbank,
            "fasta": self._export_fasta,
            "json": lambda: json.dumps(self.current_analysis, indent=2),
        }
        exporter = exporters.get(format_type.lower())
        if exporter is None:
            return "Unsupported format"
        return exporter()

    def _export_fasta(self) -> str:
        """Return the current sequence as a single FASTA record."""
        header = f">DNA_Diffusion_{self.current_analysis['cell_type']}_{datetime.now().strftime('%Y%m%d')}"
        return f"{header}\n{self.current_analysis['sequence']}"

    def _export_genbank(self) -> str:
        """Return the current sequence in a simplified GenBank-style layout."""
        sequence = self.current_analysis['sequence']
        return (
            f"LOCUS DNA_Diffusion {len(sequence)} bp DNA linear SYN {datetime.now().strftime('%d-%b-%Y')}\n"
            f"DEFINITION Synthetic DNA sequence for {self.current_analysis['cell_type']}\n"
            "ORIGIN\n"
            f"1 {sequence}\n"
            "//"
        )
# Create single app instance
# This module-level object is the one every handler wired up in
# create_enhanced_demo() calls into, so its current_analysis attribute is
# shared state. NOTE(review): presumably that means all browser sessions
# see the same "current" sequence — confirm whether per-session state is
# needed.
app = EnhancedDNAApp()
def create_enhanced_demo():
    """Build the Gradio Blocks interface and wire its events to ``app``.

    Tabs: "Generate & Analyze" (controls plus plots), "AI Assistant"
    (chat), "Tools" (primer design, codon optimisation) and "Export".

    Returns:
        The configured ``gr.Blocks`` instance; the caller launches it.
    """
    with gr.Blocks(theme=gr.themes.Base()) as demo:
        gr.Markdown("# 🧬 Enhanced DNA-Diffusion with Scientific Tools")
        with gr.Tabs():
            with gr.TabItem("🎰 Generate & Analyze"):
                with gr.Row():
                    with gr.Column(scale=2):
                        # Generation controls
                        cell_type = gr.Radio(
                            ["K562", "GM12878", "HepG2"],
                            value="K562",
                            label="Cell Type"
                        )
                        guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=1.0,
                            step=0.5,
                            label="Guidance Scale"
                        )
                        language = gr.Radio(
                            ["en", "ko"],
                            value="en",
                            label="Language"
                        )
                        generate_btn = gr.Button("🎲 Generate & Analyze", variant="primary")
                    with gr.Column(scale=3):
                        # Hidden carrier for the JSON produced by generation;
                        # the visual components below are filled from it.
                        results_json = gr.JSON(label="Analysis Results", visible=False)
                        # Visual results
                        with gr.Accordion("📊 Sequence Analysis", open=True):
                            gc_plot = gr.Plot(label="GC Content Distribution")
                            restriction_map = gr.Plot(label="Restriction Enzyme Map")
                        with gr.Accordion("🧬 Protein Analysis", open=True):
                            protein_structure = gr.HTML(label="Predicted Structure")
                            protein_properties = gr.Textbox(label="Properties", lines=5)
            with gr.TabItem("💬 AI Assistant"):
                chatbot = gr.Chatbot(label="Scientific Assistant", height=400)
                msg = gr.Textbox(label="Ask about your sequence", placeholder="e.g., 'What primers would you recommend?'")
                chat_btn = gr.Button("Send")
                # Chat examples
                gr.Examples(
                    examples=[
                        "What restriction enzymes should I use for cloning?",
                        "Can you explain the ORFs found in this sequence?",
                        "How can I optimize this sequence for E. coli expression?",
                        "What's the predicted protein structure?"
                    ],
                    inputs=msg
                )
            with gr.TabItem("🔧 Tools"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Primer Design")
                        # NOTE: this slider is displayed but its value is not
                        # forwarded; only the product size is passed on.
                        primer_length = gr.Slider(18, 25, 20, label="Primer Length")
                        product_size = gr.Slider(200, 1000, 500, label="Product Size")
                        design_primers_btn = gr.Button("Design Primers")
                        primer_results = gr.JSON(label="Designed Primers")
                    with gr.Column():
                        gr.Markdown("### Codon Optimization")
                        target_organism = gr.Dropdown(
                            ["E. coli", "Yeast", "Human", "Mouse"],
                            value="E. coli",
                            label="Target Organism"
                        )
                        optimize_btn = gr.Button("Optimize Codons")
                        optimized_seq = gr.Textbox(label="Optimized Sequence", lines=5)
            with gr.TabItem("📤 Export"):
                export_format = gr.Radio(
                    ["FASTA", "GenBank", "JSON"],
                    value="FASTA",
                    label="Export Format"
                )
                export_btn = gr.Button("Export Results")
                export_output = gr.Textbox(label="Exported Data", lines=10)

        # Wire up the interface: generate first, then render the visuals
        # from the JSON that generation produced.
        generate_btn.click(
            fn=app.generate_and_analyze,
            inputs=[cell_type, guidance_scale, language],
            outputs=[results_json]
        ).then(
            fn=visualize_results,
            inputs=[results_json],
            outputs=[gc_plot, restriction_map, protein_structure, protein_properties]
        )

        # Chat functionality
        def respond(message, chat_history, language):
            """Send one chat message; clear the textbox and extend the transcript."""
            response = asyncio.run(app.handle_chat(message, language))
            # Build a new list: tolerates an empty (None) history and avoids
            # mutating the value handed in by the UI.
            updated_history = list(chat_history or []) + [(message, response)]
            return "", updated_history

        msg.submit(respond, [msg, chatbot, language], [msg, chatbot])
        chat_btn.click(respond, [msg, chatbot, language], [msg, chatbot])

        # Tools functionality (operates on the most recently generated sequence)
        def design_primers_for_current(size):
            """Design primers for the current sequence at the requested product size."""
            if not app.current_analysis:
                return {"error": "Please generate a sequence first."}
            primers = app.analyzer.design_primers(app.current_analysis['sequence'], int(size))
            if primers is None:
                return {"error": "No suitable primer pair found for this product size."}
            return primers

        def optimize_current(organism):
            """Translate the current sequence and back-translate with preferred codons."""
            if not app.current_analysis:
                return "Please generate a sequence first."
            protein = app.translate_to_protein(app.current_analysis['sequence'])
            return app.analyzer.codon_optimize(protein, organism)

        design_primers_btn.click(
            fn=design_primers_for_current,
            inputs=[product_size],
            outputs=[primer_results]
        )
        optimize_btn.click(
            fn=optimize_current,
            inputs=[target_organism],
            outputs=[optimized_seq]
        )

        # Export functionality (radio labels are lower-cased to format keys)
        export_btn.click(
            fn=lambda fmt: app.export_results(fmt.lower()),
            inputs=[export_format],
            outputs=[export_output]
        )

        # Initialize model on load
        demo.load(fn=app.initialize_model)
    return demo
def visualize_results(results_json):
    """Create the plots and text shown for one generation result.

    Args:
        results_json: The generation payload, either as JSON text or as an
            already-parsed dict. Expected keys are ``sequence`` and
            ``analysis``, or ``error``.

    Returns:
        ``(gc_figure, restriction_figure, structure_html, properties_text)``.
        For a missing payload or one carrying ``error`` the figures are
        ``None`` and the text fields hold an error notice.
    """
    data = json.loads(results_json) if isinstance(results_json, str) else results_json
    if data is None or "error" in data:
        return None, None, "<p>Error in analysis</p>", "Error"

    # Imported only when a plot is actually drawn. Figures are built with
    # the object-oriented API rather than pyplot so they are not registered
    # in pyplot's global figure list, which would otherwise grow with every
    # request in a long-running server.
    from matplotlib.figure import Figure

    analysis = data.get('analysis', {})

    # GC content plot
    gc_content = analysis.get('gc_content', 0)
    fig1 = Figure(figsize=(8, 4))
    ax1 = fig1.subplots()
    ax1.bar(['GC%', 'AT%'], [gc_content, 100 - gc_content], color=['#00ff00', '#ff0000'])
    ax1.set_ylabel('Percentage')
    ax1.set_title('Nucleotide Composition')

    # Restriction map: one horizontal lane per enzyme, one tick per site.
    sites = analysis.get('restriction_sites', {})
    seq_len = len(data.get('sequence', ''))
    fig2 = Figure(figsize=(10, 3))
    ax2 = fig2.subplots()
    y_pos = 0
    for enzyme, positions in sites.items():
        for pos in positions:
            ax2.plot([pos, pos], [y_pos - 0.1, y_pos + 0.1], 'r-', linewidth=2)
            ax2.text(pos, y_pos + 0.15, enzyme, fontsize=8, ha='center')
        y_pos += 0.3
    # Keep the axis non-degenerate when there is no sequence.
    ax2.set_xlim(0, max(seq_len, 1))
    ax2.set_ylim(-0.5, max(0.5, y_pos))
    ax2.set_xlabel('Position (bp)')
    ax2.set_title('Restriction Enzyme Sites')

    # Protein structure (mock visualization with fixed percentages).
    # Gradient stops must be cumulative: red 0-45%, green 45-75%, blue
    # 75-100%, matching the 45/30/25 split in the caption.
    structure_html = (
        "\n"
        '<div style="padding: 20px; background: #f0f0f0; border-radius: 10px;">\n'
        "<h3>🔬 Predicted Secondary Structure</h3>\n"
        "<p>Helices: 45%, Beta sheets: 30%, Loops: 25%</p>\n"
        '<div style="background: linear-gradient(to right, '
        "#ff0000 0%, #ff0000 45%, #00ff00 45%, #00ff00 75%, #0000ff 75%, #0000ff 100%);\n"
        'height: 30px; border-radius: 5px; margin: 10px 0;"></div>\n'
        '<p style="color: #666;">3D structure prediction available in Pro version</p>\n'
        "</div>\n"
    )

    # Protein properties
    properties = analysis.get('protein_analysis', 'No analysis available')
    return fig1, fig2, structure_html, properties
# Launch the enhanced app
if __name__ == "__main__":
    # Build the interface only when this file is run as a script.
    demo = create_enhanced_demo()
    # NOTE(review): share=True requests a publicly reachable Gradio share
    # link in addition to the local server — confirm this is intended for
    # every environment this script runs in.
    demo.launch(share=True)