diff --git a/src/autolabel/dataset/dataset.py b/src/autolabel/dataset/dataset.py
index 4b2ee496..f4082c16 100644
--- a/src/autolabel/dataset/dataset.py
+++ b/src/autolabel/dataset/dataset.py
@@ -159,9 +159,9 @@ def process_labels(
                     )
                 else:
                     attr_confidence_scores.append(0.0)
-            self.df[
-                self.generate_label_name("confidence", attr["name"])
-            ] = attr_confidence_scores
+            self.df[self.generate_label_name("confidence", attr["name"])] = (
+                attr_confidence_scores
+            )
 
         # Add the LLM explanations to the dataframe if chain of thought is set in config
         if self.config.chain_of_thought():
@@ -377,7 +377,7 @@ def _validate(self):
                 f"Validation failed for {len(self.__malformed_records)} rows.",
             )
 
-    def generate_label_name(self, col_name: str, label_column: str = None):
+    def generate_label_name(self, col_name: str, label_column: str = None) -> str:
         label_column = label_column or f"{self.config.task_name()}_task"
         return f"{label_column}_{col_name}"
 
diff --git a/src/autolabel/models/openai_vision.py b/src/autolabel/models/openai_vision.py
index 4c3d2727..f24b00d4 100644
--- a/src/autolabel/models/openai_vision.py
+++ b/src/autolabel/models/openai_vision.py
@@ -84,6 +84,7 @@ def __init__(
         )
         self.tiktoken = tiktoken
         self.image_cols = config.image_columns()
+        self.input_cols = config.input_columns()
 
     def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult:
         generations = []
@@ -95,7 +96,8 @@ def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult:
             if self.image_cols:
                 for col in self.image_cols:
                     if (
-                        parsed_prompt.get(col) is not None
+                        col in self.input_cols
+                        and parsed_prompt.get(col) is not None
                         and len(parsed_prompt[col]) > 0
                     ):
                         content.append(
diff --git a/src/autolabel/task_chain/task_chain.py b/src/autolabel/task_chain/task_chain.py
index 80fa4996..ae48b2d3 100644
--- a/src/autolabel/task_chain/task_chain.py
+++ b/src/autolabel/task_chain/task_chain.py
@@ -136,6 +136,7 @@ def __init__(
         self.confidence_endpoint = confidence_endpoint
         self.column_name_map = column_name_map
         self.label_selector_map = label_selector_map
+        self.s3_client = boto3.client("s3")
 
     # TODO: For now, we run each separate step of the task chain serially and aggregate at the end.
     # We can optimize this with parallelization where possible/no dependencies.
@@ -155,6 +156,10 @@ async def run(self, dataset_df: pd.DataFrame):
         for task in subtasks:
             autolabel_config = AutolabelConfig(task)
             dataset = AutolabelDataset(dataset_df, autolabel_config)
+            dataset, original_inputs = self.safe_convert_uri_to_presigned_url(
+                dataset,
+                autolabel_config,
+            )
             if autolabel_config.transforms():
                 agent = LabelingAgent(
                     config=autolabel_config,
@@ -191,6 +196,11 @@ async def run(self, dataset_df: pd.DataFrame):
                     dataset,
                     skip_eval=True,
                 )
+            dataset = self.reset_presigned_url_to_uri(
+                dataset,
+                original_inputs,
+                autolabel_config,
+            )
             dataset = self.rename_output_columns(dataset, autolabel_config)
             dataset_df = dataset.df
         return dataset
@@ -218,3 +228,34 @@ def rename_output_columns(
             ].apply(lambda x: x.get(attribute) if x and type(x) is dict else None)
 
         return dataset
+
+    def safe_convert_uri_to_presigned_url(
+        self,
+        dataset: AutolabelDataset,
+        autolabel_config: AutolabelConfig,
+    ) -> Tuple[AutolabelDataset, List[Dict]]:
+        original_inputs = copy.deepcopy(dataset.inputs)
+        for col in autolabel_config.input_columns():
+            for i in range(len(dataset.inputs)):
+                dataset.inputs[i][col] = (
+                    generate_presigned_url(
+                        self.s3_client,
+                        dataset.inputs[i][col],
+                    )
+                    if is_s3_uri(dataset.inputs[i][col])
+                    else dataset.inputs[i][col]
+                )
+                dataset.df.loc[i, col] = dataset.inputs[i][col]
+        return dataset, original_inputs
+
+    def reset_presigned_url_to_uri(
+        self,
+        dataset: AutolabelDataset,
+        original_inputs: List[Dict],
+        autolabel_config: AutolabelConfig,
+    ) -> AutolabelDataset:
+        for col in autolabel_config.input_columns():
+            for i in range(len(dataset.inputs)):
+                dataset.inputs[i][col] = original_inputs[i][col]
+                dataset.df.loc[i, col] = dataset.inputs[i][col]
+        return dataset
diff --git a/src/autolabel/tasks/attribute_extraction.py b/src/autolabel/tasks/attribute_extraction.py
index 6c131269..d6ed8415 100644
--- a/src/autolabel/tasks/attribute_extraction.py
+++ b/src/autolabel/tasks/attribute_extraction.py
@@ -281,7 +281,11 @@ def construct_prompt(
         if self.image_cols:
             prompt_dict = {"text": curr_text_prompt}
             for col in self.image_cols:
-                if input.get(col) is not None and len(input.get(col)) > 0:
+                if (
+                    col in self.input_cols
+                    and input.get(col) is not None
+                    and len(input.get(col)) > 0
+                ):
                     prompt_dict[col] = input[col]
                 prompt_dict[col] = input[col]
             return json.dumps(prompt_dict), output_schema
diff --git a/src/autolabel/tasks/base.py b/src/autolabel/tasks/base.py
index 8c203a64..969e3b58 100644
--- a/src/autolabel/tasks/base.py
+++ b/src/autolabel/tasks/base.py
@@ -37,7 +37,7 @@ class BaseTask(ABC):
     def __init__(self, config: AutolabelConfig) -> None:
         self.config = config
         self.image_cols = self.config.image_columns()
-
+        self.input_cols = self.config.input_columns()
         # Update the default prompt template with the prompt template from the config
         self.task_guidelines = (
             self.config.task_guidelines() or self.DEFAULT_TASK_GUIDELINES
diff --git a/src/autolabel/transforms/ocr.py b/src/autolabel/transforms/ocr.py
index 5855f39e..74ae8251 100644
--- a/src/autolabel/transforms/ocr.py
+++ b/src/autolabel/transforms/ocr.py
@@ -160,7 +160,7 @@ async def _apply(self, row: dict[str, Any]) -> dict[str, Any]:
                 ) from exc
 
         ocr_output = []
-        if curr_file_path.endswith(".pdf"):
+        if Path(curr_file_path).suffix.lower().startswith(".pdf"):
             pages = self.convert_from_path(curr_file_path)
             ocr_output = [
                 self.default_ocr_processor(page, lang=self.lang) for page in pages
diff --git a/src/autolabel/transforms/serp_api.py b/src/autolabel/transforms/serp_api.py
index 76e554db..f627a2c5 100644
--- a/src/autolabel/transforms/serp_api.py
+++ b/src/autolabel/transforms/serp_api.py
@@ -101,6 +101,7 @@ async def _get_result(self, query):
         return search_result
 
     async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
+        start_time = time.time()
         for col in self.query_columns:
             if col not in row:
                 logger.warning(
@@ -124,6 +125,10 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
                     "organic_results",
                 ),
             }
+        end_time = time.time()
+        logger.info(
+            f"Time taken to run Serp API: {end_time - start_time} seconds",
+        )
 
         return self._return_output_row(transformed_row)
 
diff --git a/src/autolabel/utils.py b/src/autolabel/utils.py
index 49dbb407..63cf851c 100644
--- a/src/autolabel/utils.py
+++ b/src/autolabel/utils.py
@@ -8,6 +8,7 @@
 import string
 from string import Formatter
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
+from urllib.parse import urlparse
 
 import regex
 import wget
@@ -438,3 +439,38 @@ def safe_serialize_to_string(data: Dict) -> Dict:
         except Exception:
             ret[k] = ""
     return ret
+
+
+def is_s3_uri(uri_string: str) -> bool:
+    return uri_string is not None and (
+        uri_string.startswith("s3://") or uri_string.startswith("s3a://")
+    )
+
+
+def extract_bucket_key_from_s3_url(s3_path: str) -> Optional[Dict[str, str]]:
+    # Refer: https://stackoverflow.com/a/48245084
+    if not is_s3_uri(s3_path):
+        logger.warning(f"URI is not actually an S3 URI: {s3_path}")
+        return None
+
+    path_object = urlparse(s3_path)
+    bucket = path_object.netloc
+    key = path_object.path
+    return {"Bucket": bucket, "Key": key.lstrip("/")}
+
+
+def generate_s3_uri_from_bucket_key(bucket: str, key: str) -> str:
+    return f"s3://{bucket}/{key}"
+
+
+def generate_presigned_url(client, s3_uri: str, expiration: int = 86400) -> str:
+    s3_params = extract_bucket_key_from_s3_url(s3_uri)
+
+    if not s3_params:
+        return s3_uri
+
+    return client.generate_presigned_url(
+        ClientMethod="get_object",
+        Params={"Bucket": s3_params["Bucket"], "Key": s3_params["Key"]},
+        ExpiresIn=expiration,
+    )
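
Usage note (outside the diff): a minimal sketch of how the new S3 helpers in src/autolabel/utils.py are meant to compose, mirroring the safe_convert_uri_to_presigned_url / reset_presigned_url_to_uri round trip added to task_chain.py. The bucket, key, and row value below are hypothetical placeholders, not values from this patch.

    import boto3

    from autolabel.utils import generate_presigned_url, is_s3_uri

    s3_client = boto3.client("s3")
    original_uri = "s3://example-bucket/images/photo.png"  # hypothetical input

    # Swap the s3:// URI for a temporary HTTPS URL only when the value really
    # is an S3 URI; plain text and https:// values pass through unchanged.
    url = (
        generate_presigned_url(s3_client, original_uri, expiration=86400)
        if is_s3_uri(original_uri)
        else original_uri
    )

    # ... run the labeling step with `url` substituted into the input row ...

    # Restore the stored value afterwards, as reset_presigned_url_to_uri does,
    # so the expiring presigned URL never persists in the dataset.
    restored_value = original_uri

Restoring the original URI matters because presigned URLs expire (86400 seconds, i.e. 24 hours, with the default above); persisting them would leave the output dataset holding dead links.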
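For reference, extract_bucket_key_from_s3_url leans on urllib.parse.urlparse, which puts the bucket in netloc and the object key in path. A self-contained check (the URI is made up):

    from urllib.parse import urlparse

    parts = urlparse("s3://my-bucket/images/cat.png")
    assert parts.netloc == "my-bucket"
    assert parts.path == "/images/cat.png"
    # So the helper returns {"Bucket": "my-bucket", "Key": "images/cat.png"},
    # the exact shape client.generate_presigned_url expects in Params.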