# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union

import mmcv
import mmengine
import numpy as np
from rich.progress import track

from mmocr.registry import VISUALIZERS
from mmocr.structures import TextSpottingDataSample
from mmocr.utils import ConfigType, bbox2poly, crop_img, poly2bbox
from .base_mmocr_inferencer import (BaseMMOCRInferencer, InputsType, PredType,
                                    ResType)
from .kie_inferencer import KIEInferencer
from .textdet_inferencer import TextDetInferencer
from .textrec_inferencer import TextRecInferencer


class MMOCRInferencer(BaseMMOCRInferencer):
    """MMOCR Inferencer. It's a wrapper around three base task
    inferencers: TextDetInferencer, TextRecInferencer and KIEInferencer,
    and it can be used to perform end-to-end OCR or KIE inference.

    Args:
        det (Optional[Union[ConfigType, str]]): Pretrained text detection
            algorithm. It's the path to the config file or the model name
            defined in metafile. Defaults to None.
        det_weights (Optional[str]): Path to the custom checkpoint file of
            the selected det model. If it is not specified and "det" is a
            model name of metafile, the weights will be loaded from
            metafile. Defaults to None.
        rec (Optional[Union[ConfigType, str]]): Pretrained text recognition
            algorithm. It's the path to the config file or the model name
            defined in metafile. Defaults to None.
        rec_weights (Optional[str]): Path to the custom checkpoint file of
            the selected rec model. If it is not specified and "rec" is a
            model name of metafile, the weights will be loaded from
            metafile. Defaults to None.
        kie (Optional[Union[ConfigType, str]]): Pretrained key information
            extraction algorithm. It's the path to the config file or the
            model name defined in metafile. Defaults to None.
        kie_weights (Optional[str]): Path to the custom checkpoint file of
            the selected kie model. If it is not specified and "kie" is a
            model name of metafile, the weights will be loaded from
            metafile. Defaults to None.
        device (Optional[str]): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
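
    Example:
        A minimal usage sketch. The model names below are assumptions that
        must exist in the metafiles of the installed MMOCR version, and the
        image path is illustrative:

        >>> from mmocr.apis import MMOCRInferencer
        >>> ocr = MMOCRInferencer(det='DBNet', rec='CRNN')
        >>> results = ocr('demo_text_ocr.jpg', save_vis=True)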
| """ | |

    def __init__(self,
                 det: Optional[Union[ConfigType, str]] = None,
                 det_weights: Optional[str] = None,
                 rec: Optional[Union[ConfigType, str]] = None,
                 rec_weights: Optional[str] = None,
                 kie: Optional[Union[ConfigType, str]] = None,
                 kie_weights: Optional[str] = None,
                 device: Optional[str] = None) -> None:
        if det is None and rec is None and kie is None:
            raise ValueError('At least one of det, rec and kie should be '
                             'provided.')

        self.visualizer = None

        if det is not None:
            self.textdet_inferencer = TextDetInferencer(
                det, det_weights, device)
            self.mode = 'det'
        if rec is not None:
            self.textrec_inferencer = TextRecInferencer(
                rec, rec_weights, device)
            if getattr(self, 'mode', None) == 'det':
                self.mode = 'det_rec'
                ts = str(datetime.timestamp(datetime.now()))
                self.visualizer = VISUALIZERS.build(
                    dict(
                        type='TextSpottingLocalVisualizer',
                        name=f'inferencer{ts}',
                        font_families=self.textrec_inferencer.visualizer.
                        font_families))
            else:
                self.mode = 'rec'
        if kie is not None:
            if det is None or rec is None:
                raise ValueError(
                    'kie is only applicable when det and rec are both '
                    'provided')
            self.kie_inferencer = KIEInferencer(kie, kie_weights, device)
            self.mode = 'det_rec_kie'

    def _inputs2ndarray(self, inputs: List[InputsType]) -> List[np.ndarray]:
        """Preprocess the inputs to a list of numpy arrays."""
        new_inputs = []
        for item in inputs:
            if isinstance(item, np.ndarray):
                new_inputs.append(item)
            elif isinstance(item, str):
                img_bytes = mmengine.fileio.get(item)
                new_inputs.append(mmcv.imfrombytes(img_bytes))
            else:
                raise NotImplementedError(f'The input type {type(item)} is '
                                          'not supported yet.')
        return new_inputs

    def forward(self,
                inputs: InputsType,
                batch_size: int = 1,
                det_batch_size: Optional[int] = None,
                rec_batch_size: Optional[int] = None,
                kie_batch_size: Optional[int] = None,
                **forward_kwargs) -> PredType:
        """Forward the inputs to the model.

        Args:
            inputs (InputsType): The inputs to be forwarded.
            batch_size (int): Batch size. Defaults to 1.
            det_batch_size (Optional[int]): Batch size for the text
                detection model. Overwrites ``batch_size`` if it is not
                None. Defaults to None.
            rec_batch_size (Optional[int]): Batch size for the text
                recognition model. Overwrites ``batch_size`` if it is not
                None. Defaults to None.
            kie_batch_size (Optional[int]): Batch size for the KIE model.
                Overwrites ``batch_size`` if it is not None. Defaults to
                None.

        Returns:
            Dict: The prediction results, possibly with keys "det", "rec"
            and "kie".
| """ | |
        result = {}
        forward_kwargs['progress_bar'] = False
        if det_batch_size is None:
            det_batch_size = batch_size
        if rec_batch_size is None:
            rec_batch_size = batch_size
        if kie_batch_size is None:
            kie_batch_size = batch_size
        if self.mode == 'rec':
            # The extra list wrapper here is for the ease of postprocessing
            self.rec_inputs = inputs
            predictions = self.textrec_inferencer(
                self.rec_inputs,
                return_datasamples=True,
                batch_size=rec_batch_size,
                **forward_kwargs)['predictions']
            result['rec'] = [[p] for p in predictions]
        elif self.mode.startswith('det'):  # 'det'/'det_rec'/'det_rec_kie'
            result['det'] = self.textdet_inferencer(
                inputs,
                return_datasamples=True,
                batch_size=det_batch_size,
                **forward_kwargs)['predictions']
            if self.mode.startswith('det_rec'):  # 'det_rec'/'det_rec_kie'
                result['rec'] = []
                for img, det_data_sample in zip(
                        self._inputs2ndarray(inputs), result['det']):
                    det_pred = det_data_sample.pred_instances
                    self.rec_inputs = []
                    for polygon in det_pred['polygons']:
                        # Roughly convert the polygon to a quadrangle with
                        # 4 points
                        quad = bbox2poly(poly2bbox(polygon)).tolist()
                        self.rec_inputs.append(crop_img(img, quad))
                    result['rec'].append(
                        self.textrec_inferencer(
                            self.rec_inputs,
                            return_datasamples=True,
                            batch_size=rec_batch_size,
                            **forward_kwargs)['predictions'])
            if self.mode == 'det_rec_kie':
                self.kie_inputs = []
                # TODO: when the det output is empty, kie will fail
                # as no gt-instances can be provided. It's a known
                # issue but cannot be solved elegantly since we support
                # batch inference.
                for img, det_data_sample, rec_data_samples in zip(
                        inputs, result['det'], result['rec']):
                    det_pred = det_data_sample.pred_instances
                    kie_input = dict(img=img)
                    kie_input['instances'] = []
                    for polygon, rec_data_sample in zip(
                            det_pred['polygons'], rec_data_samples):
                        kie_input['instances'].append(
                            dict(
                                bbox=poly2bbox(polygon),
                                text=rec_data_sample.pred_text.item))
                    self.kie_inputs.append(kie_input)
                result['kie'] = self.kie_inferencer(
                    self.kie_inputs,
                    return_datasamples=True,
                    batch_size=kie_batch_size,
                    **forward_kwargs)['predictions']
        return result

    def visualize(self, inputs: InputsType, preds: PredType,
                  **kwargs) -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the
                inferencer.
            preds (List[Dict]): Predictions of the model.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            save_vis (bool): Whether to save the visualization result.
                Defaults to False.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Returns visualization results only if
            applicable.
        """
        if 'kie' in self.mode:
            return self.kie_inferencer.visualize(self.kie_inputs,
                                                 preds['kie'], **kwargs)
        elif 'rec' in self.mode:
            if 'det' in self.mode:
                return super().visualize(inputs,
                                         self._pack_e2e_datasamples(preds),
                                         **kwargs)
            else:
                return self.textrec_inferencer.visualize(
                    self.rec_inputs, preds['rec'][0], **kwargs)
        else:
            return self.textdet_inferencer.visualize(inputs, preds['det'],
                                                     **kwargs)

    def __call__(
        self,
        inputs: InputsType,
        batch_size: int = 1,
        det_batch_size: Optional[int] = None,
        rec_batch_size: Optional[int] = None,
        kie_batch_size: Optional[int] = None,
        out_dir: str = 'results/',
        return_vis: bool = False,
        save_vis: bool = False,
        save_pred: bool = False,
        **kwargs,
    ) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer. It can be a
                path to an image / image directory, an array, or a list of
                these.
            batch_size (int): Batch size. Defaults to 1.
            det_batch_size (Optional[int]): Batch size for the text
                detection model. Overwrites ``batch_size`` if it is not
                None. Defaults to None.
            rec_batch_size (Optional[int]): Batch size for the text
                recognition model. Overwrites ``batch_size`` if it is not
                None. Defaults to None.
            kie_batch_size (Optional[int]): Batch size for the KIE model.
                Overwrites ``batch_size`` if it is not None. Defaults to
                None.
            out_dir (str): Output directory of results. Defaults to
                'results/'.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            save_vis (bool): Whether to save the visualization results to
                "out_dir". Defaults to False.
            save_pred (bool): Whether to save the inference results to
                "out_dir". Defaults to False.
            **kwargs: Keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results, mapped from
            "predictions" and "visualization".
| """ | |
        if (save_vis or save_pred) and not out_dir:
            raise ValueError('out_dir must be specified when save_vis or '
                             'save_pred is True!')
        if out_dir:
            img_out_dir = osp.join(out_dir, 'vis')
            pred_out_dir = osp.join(out_dir, 'preds')
        else:
            img_out_dir, pred_out_dir = '', ''

        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(
            save_vis=save_vis,
            save_pred=save_pred,
            return_vis=return_vis,
            **kwargs)

        ori_inputs = self._inputs_to_list(inputs)
        if det_batch_size is None:
            det_batch_size = batch_size
        if rec_batch_size is None:
            rec_batch_size = batch_size
        if kie_batch_size is None:
            kie_batch_size = batch_size
        chunked_inputs = super(BaseMMOCRInferencer,
                               self)._get_chunk_data(ori_inputs, batch_size)
        results = {'predictions': [], 'visualization': []}
        for ori_input in track(chunked_inputs, description='Inference'):
            preds = self.forward(
                ori_input,
                det_batch_size=det_batch_size,
                rec_batch_size=rec_batch_size,
                kie_batch_size=kie_batch_size,
                **forward_kwargs)
            visualization = self.visualize(
                ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
            batch_res = self.postprocess(
                preds,
                visualization,
                pred_out_dir=pred_out_dir,
                **postprocess_kwargs)
            results['predictions'].extend(batch_res['predictions'])
            if return_vis and batch_res['visualization'] is not None:
                results['visualization'].extend(batch_res['visualization'])
        return results

    def postprocess(self,
                    preds: PredType,
                    visualization: Optional[List[np.ndarray]] = None,
                    print_result: bool = False,
                    save_pred: bool = False,
                    pred_out_dir: str = ''
                    ) -> Union[ResType, Tuple[ResType, np.ndarray]]:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Args:
            preds (PredType): Predictions of the model.
            visualization (Optional[List[np.ndarray]]): Visualized
                predictions.
            print_result (bool): Whether to print the result.
                Defaults to False.
            save_pred (bool): Whether to save the inference result. Defaults
                to False.
            pred_out_dir (str): Directory to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            Dict: Inference and visualization results, mapped from
            "predictions" and "visualization".
| """ | |
        result_dict = {}
        pred_results = [{} for _ in range(len(next(iter(preds.values()))))]
        if 'rec' in self.mode:
            for i, rec_pred in enumerate(preds['rec']):
                result = dict(rec_texts=[], rec_scores=[])
                for rec_pred_instance in rec_pred:
                    rec_dict_res = self.textrec_inferencer.pred2dict(
                        rec_pred_instance)
                    result['rec_texts'].append(rec_dict_res['text'])
                    result['rec_scores'].append(rec_dict_res['scores'])
                pred_results[i].update(result)
        if 'det' in self.mode:
            for i, det_pred in enumerate(preds['det']):
                det_dict_res = self.textdet_inferencer.pred2dict(det_pred)
                pred_results[i].update(
                    dict(
                        det_polygons=det_dict_res['polygons'],
                        det_scores=det_dict_res['scores']))
        if 'kie' in self.mode:
            for i, kie_pred in enumerate(preds['kie']):
                kie_dict_res = self.kie_inferencer.pred2dict(kie_pred)
                pred_results[i].update(
                    dict(
                        kie_labels=kie_dict_res['labels'],
                        kie_scores=kie_dict_res['scores'],
                        kie_edge_scores=kie_dict_res['edge_scores'],
                        kie_edge_labels=kie_dict_res['edge_labels']))
        if save_pred and pred_out_dir:
            pred_key = 'det' if 'det' in self.mode else 'rec'
            for pred, pred_result in zip(preds[pred_key], pred_results):
                img_path = (
                    pred.img_path if pred_key == 'det' else pred[0].img_path)
                pred_name = osp.splitext(osp.basename(img_path))[0]
                pred_name = f'{pred_name}.json'
                pred_out_file = osp.join(pred_out_dir, pred_name)
                mmengine.dump(pred_result, pred_out_file)

        result_dict['predictions'] = pred_results
        if print_result:
            print(result_dict)
        result_dict['visualization'] = visualization
        return result_dict

    def _pack_e2e_datasamples(self,
                              preds: Dict) -> List[TextSpottingDataSample]:
        """Pack text detection and recognition results into a list of
        TextSpottingDataSample."""
        results = []
        for det_data_sample, rec_data_samples in zip(preds['det'],
                                                     preds['rec']):
            texts = []
            for rec_data_sample in rec_data_samples:
                texts.append(rec_data_sample.pred_text.item)
            det_data_sample.pred_instances.texts = texts
            results.append(det_data_sample)
        return results