| from transformers import PretrainedConfig, PreTrainedModel, Pipeline
|
| import torch
|
|
|
| from BeamDiffusionModel.beamInference import beam_inference
|
| from BeamDiffusionModel.models.diffusionModel.StableDiffusion import StableDiffusion
|
| from BeamDiffusionModel.models.diffusionModel.Flux import Flux
|
|
|
class BeamDiffusionConfig(PretrainedConfig):
    """Configuration for beam-search diffusion inference.

    Besides storing the beam-search settings, the constructor instantiates
    the diffusion backend selected by ``sd`` and keeps the live object on the
    config as ``self.sd``.

    NOTE(review): keeping a model object on a config probably prevents the
    config from being JSON-serialized (``save_pretrained`` / ``repr``) --
    confirm before relying on those code paths.

    Args:
        sd: Backend name, ``"SD-2.1"`` or ``"flux"``. ``None`` loads nothing.
        latents_idx: Stored as ``latents_idx``; any falsy value (``None`` or
            an empty list) is replaced by ``[0, 1, 2, 3]``.
        n_seeds: Stored as ``n_seeds``.
        seeds: Stored as ``seeds``; any falsy value becomes ``[]``.
        steps_back: Stored as ``steps_back``.
        beam_width: Stored as ``beam_width``.
        window_size: Stored as ``window_size``.
        use_rand: Stored as ``use_rand``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``sd`` is neither ``None`` nor a supported name.
    """

    model_type = "beam_diffusion"

    def __init__(self, sd="SD-2.1", latents_idx=None, n_seeds=4, seeds=None,
                 steps_back=2, beam_width=4, window_size=2, use_rand=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.sd_name = sd
        self.sd = None
        self.get_model(sd)
        # Fresh lists per instance so configs never share mutable defaults.
        self.latents_idx = latents_idx if latents_idx else [0, 1, 2, 3]
        self.n_seeds = n_seeds
        self.seeds = seeds if seeds else []
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand

    def get_model(self, sd=None):
        """Instantiate the diffusion backend and store it in ``self.sd``.

        Args:
            sd: Backend name to load. When ``None``, ``self.sd_name`` is used.
                If that is ``None`` as well, nothing is loaded and ``self.sd``
                is left untouched.

        Raises:
            ValueError: If the resolved name is not a supported backend.
        """
        name = self.sd_name if sd is None else sd
        if name is None:
            return

        backends = {"flux": Flux, "SD-2.1": StableDiffusion}
        try:
            backend_cls = backends[name]
        except (KeyError, TypeError):
            # TypeError covers unhashable values such as a list.
            raise ValueError(
                f"Unsupported diffusion backend {name!r}; "
                f"expected one of {sorted(backends)}"
            ) from None

        # Keep the recorded name in sync with the object actually loaded.
        self.sd_name = name
        self.sd = backend_cls()
|
|
|
| import torch.nn as nn
|
| from huggingface_hub import ModelHubMixin
|
|
|
class BeamDiffusionModel(PreTrainedModel, ModelHubMixin):
    """Model wrapper that delegates generation to ``beam_inference``.

    All search settings, and the diffusion backend itself, are read from the
    attached ``BeamDiffusionConfig``.
    """

    model_type = "beam_diffusion"
    config_class = BeamDiffusionConfig

    def __init__(self, config):
        """Store ``config`` and register a single zero-valued parameter."""
        super().__init__(config)
        self.config = config
        # Presumably present so the wrapper owns at least one parameter
        # (device / state-dict handling) -- TODO confirm.
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def forward(self, input_data):
        """Run beam inference and wrap the result.

        Args:
            input_data: Mapping-like object; its ``'steps'`` entry (default
                ``[]``) is passed to ``beam_inference`` as ``steps``.

        Returns:
            A dict ``{"images": <result of beam_inference>}``.
        """
        cfg = self.config
        backend = cfg.sd
        search_options = {
            "steps": input_data.get('steps', []),
            "latents_idx": cfg.latents_idx,
            "n_seeds": cfg.n_seeds,
            "seeds": cfg.seeds,
            "steps_back": cfg.steps_back,
            "beam_width": cfg.beam_width,
            "window_size": cfg.window_size,
            "use_rand": cfg.use_rand,
        }
        generated = beam_inference(backend, **search_options)
        return {"images": generated}
|
|
|
|
|
|
|
class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    """Minimal pipeline that hands its inputs straight to the wrapped model."""

    def __init__(self, model, tokenizer=None, device="cuda", framework="pt"):
        """Forward every argument, by keyword, to ``Pipeline.__init__``."""
        base_kwargs = {
            "model": model,
            "tokenizer": tokenizer,
            "device": device,
            "framework": framework,
        }
        super().__init__(**base_kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Ignore all keyword arguments and return three empty dicts."""
        return tuple({} for _ in range(3))

    def preprocess(self, inputs):
        """Return ``inputs`` unchanged; no conversion is performed."""
        return inputs

    def _forward(self, model_inputs):
        """Call the wrapped model on ``model_inputs`` and return its output."""
        result = self.model(model_inputs)
        return result

    def postprocess(self, model_outputs):
        """Extract the ``"images"`` entry from the model output."""
        images = model_outputs["images"]
        return images

    def __call__(self, inputs):
        """Run ``_forward`` on ``inputs`` and return its result as-is.

        ``preprocess`` and ``postprocess`` are not invoked on this path.
        """
        outputs = self._forward(inputs)
        return outputs
|
|
|