Spaces:
Running
Running
Mohit Agarwal
go : run `go mod tidy` before building examples + fix permissions (#296)
8517b79 unverified | package main | |
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"syscall"
	"time"
)
| /////////////////////////////////////////////////////////////////////////////// | |
| // CONSTANTS | |
// Location of the model files. URLForModel combines these as
// srcUrl with its path set to srcPathPrefix + "-" + <model> + srcExt.
const (
	srcUrl        = "https://huggingface.co/"                            // host serving the models
	srcPathPrefix = "/datasets/ggerganov/whisper.cpp/resolve/main/ggml" // path prefix shared by every model file
	srcExt        = ".bin"                                               // extension of every model file
	bufSize       = 64 * 1024                                            // read buffer size (bytes) used while downloading
)

// modelNames is the list of models fetched when no model is given on the
// command line.
var modelNames = []string{
	"tiny.en", "tiny",
	"base.en", "base",
	"small.en", "small",
	"medium.en", "medium",
	"large-v1", "large",
}

// Command-line flags.
var (
	// flagOut is the destination directory; empty means the current working directory.
	flagOut = flag.String("out", "", "Output folder")
	// flagTimeout bounds how long the HTTP client may spend on a single model.
	flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")
	// flagQuiet suppresses progress output when set.
	flagQuiet = flag.Bool("quiet", false, "Quiet mode")
)
| /////////////////////////////////////////////////////////////////////////////// | |
| // MAIN | |
| func main() { | |
| flag.Usage = func() { | |
| name := filepath.Base(flag.CommandLine.Name()) | |
| fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] <model>\n\n", name) | |
| flag.PrintDefaults() | |
| } | |
| flag.Parse() | |
| // Get output path | |
| out, err := GetOut() | |
| if err != nil { | |
| fmt.Fprintln(os.Stderr, "Error:", err) | |
| os.Exit(-1) | |
| } | |
| // Create context which quits on SIGINT or SIGQUIT | |
| ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT) | |
| // Progress filehandle | |
| progress := os.Stdout | |
| if *flagQuiet { | |
| progress, err = os.Open(os.DevNull) | |
| if err != nil { | |
| fmt.Fprintln(os.Stderr, "Error:", err) | |
| os.Exit(-1) | |
| } | |
| defer progress.Close() | |
| } | |
| // Download models - exit on error or interrupt | |
| for _, model := range GetModels() { | |
| url, err := URLForModel(model) | |
| if err != nil { | |
| fmt.Fprintln(os.Stderr, "Error:", err) | |
| continue | |
| } else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF { | |
| continue | |
| } else if err == context.Canceled { | |
| os.Remove(path) | |
| fmt.Fprintln(progress, "\nInterrupted") | |
| break | |
| } else if err == context.DeadlineExceeded { | |
| os.Remove(path) | |
| fmt.Fprintln(progress, "Timeout downloading model") | |
| continue | |
| } else { | |
| os.Remove(path) | |
| fmt.Fprintln(os.Stderr, "Error:", err) | |
| break | |
| } | |
| } | |
| } | |
| /////////////////////////////////////////////////////////////////////////////// | |
| // PUBLIC METHODS | |
| // GetOut returns the path to the output directory | |
| func GetOut() (string, error) { | |
| if *flagOut == "" { | |
| return os.Getwd() | |
| } | |
| if info, err := os.Stat(*flagOut); err != nil { | |
| return "", err | |
| } else if !info.IsDir() { | |
| return "", fmt.Errorf("not a directory: %s", info.Name()) | |
| } else { | |
| return *flagOut, nil | |
| } | |
| } | |
| // GetModels returns the list of models to download | |
| func GetModels() []string { | |
| if flag.NArg() == 0 { | |
| return modelNames | |
| } else { | |
| return flag.Args() | |
| } | |
| } | |
| // URLForModel returns the URL for the given model on huggingface.co | |
| func URLForModel(model string) (string, error) { | |
| url, err := url.Parse(srcUrl) | |
| if err != nil { | |
| return "", err | |
| } else { | |
| url.Path = srcPathPrefix + "-" + model + srcExt | |
| } | |
| return url.String(), nil | |
| } | |
| // Download downloads the model from the given URL to the given output directory | |
| func Download(ctx context.Context, p io.Writer, model, out string) (string, error) { | |
| // Create HTTP client | |
| client := http.Client{ | |
| Timeout: *flagTimeout, | |
| } | |
| // Initiate the download | |
| req, err := http.NewRequest("GET", model, nil) | |
| if err != nil { | |
| return "", err | |
| } | |
| resp, err := client.Do(req) | |
| if err != nil { | |
| return "", err | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode != http.StatusOK { | |
| return "", fmt.Errorf("%s: %s", model, resp.Status) | |
| } | |
| // If output file exists and is the same size as the model, skip | |
| path := filepath.Join(out, filepath.Base(model)) | |
| if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength { | |
| fmt.Fprintln(p, "Skipping", model, "as it already exists") | |
| return "", nil | |
| } | |
| // Create file | |
| w, err := os.Create(path) | |
| if err != nil { | |
| return "", err | |
| } | |
| defer w.Close() | |
| // Report | |
| fmt.Fprintln(p, "Downloading", model, "to", out) | |
| // Progressively download the model | |
| data := make([]byte, bufSize) | |
| count, pct := int64(0), int64(0) | |
| ticker := time.NewTicker(5 * time.Second) | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| // Cancelled, return error | |
| return path, ctx.Err() | |
| case <-ticker.C: | |
| pct = DownloadReport(p, pct, count, resp.ContentLength) | |
| default: | |
| // Read body | |
| n, err := resp.Body.Read(data) | |
| if err != nil { | |
| DownloadReport(p, pct, count, resp.ContentLength) | |
| return path, err | |
| } else if m, err := w.Write(data[:n]); err != nil { | |
| return path, err | |
| } else { | |
| count += int64(m) | |
| } | |
| } | |
| } | |
| } | |
// DownloadReport writes a progress line to w and returns the percentage of
// total that count represents.
//
// pct is the percentage returned by the previous call. A line
// " ...<MB> MB written (<pct>%)" is written only when the new percentage
// exceeds it. When total is not positive (an unknown Content-Length of -1,
// or an empty body), no percentage can be computed. In that case the
// megabytes written are reported and pct is returned unchanged. The previous
// code divided by zero (a panic) or produced a negative percentage.
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
	if total <= 0 {
		fmt.Fprintf(w, " ...%d MB written\n", count/1e6)
		return pct
	}
	current := count * 100 / total
	if current > pct {
		fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, current)
	}
	return current
}