Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting sending images and files in Refuel Applications #938

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions src/autolabel/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Dict, List, Union, Optional
from typing import Callable, Dict, List, Optional, Union

import pandas as pd
from rich.console import Console
Expand Down Expand Up @@ -39,12 +39,14 @@ def __init__(
) -> None:
"""
Initializes the dataset.

Args:
dataset: The dataset to be used for labeling. Could be a path to a csv/jsonl file or a pandas dataframe.
config: The config to be used for labeling. Could be a path to a json file or a dictionary.
max_items: The maximum number of items to be parsed into the dataset object.
start_index: The index to start parsing the dataset from.
validate: Whether to validate the dataset or not.

"""
if not (isinstance(config, AutolabelConfig)):
self.config = AutolabelConfig(config)
Expand Down Expand Up @@ -105,7 +107,9 @@ def get_slice(self, max_items: int = None, start_index: int = 0):
return AutolabelDataset(df, self.config)

def process_labels(
self, llm_labels: List[LLMAnnotation], metrics: List[MetricResult] = None
self,
llm_labels: List[LLMAnnotation],
metrics: List[MetricResult] = None,
):
# Add the LLM labels to the dataframe
self.df[self.generate_label_name("label")] = [x.label for x in llm_labels]
Expand Down Expand Up @@ -152,13 +156,13 @@ def process_labels(
for x in llm_labels:
if x.successfully_labeled:
attr_confidence_scores.append(
x.confidence_score.get(attr["name"], 0.0)
x.confidence_score.get(attr["name"], 0.0),
)
else:
attr_confidence_scores.append(0.0)
self.df[
self.generate_label_name("confidence", attr["name"])
] = attr_confidence_scores
self.df[self.generate_label_name("confidence", attr["name"])] = (
attr_confidence_scores
)

# Add the LLM explanations to the dataframe if chain of thought is set in config
if self.config.chain_of_thought():
Expand All @@ -169,8 +173,10 @@ def process_labels(
def save(self, output_file_name: str):
"""
Saves the dataset to a file based on the file extension.

Args:
output_file_name: The name of the file to save the dataset to. Based on the extension we can save to a csv or jsonl file.

"""
if output_file_name.endswith(".csv"):
self.df.to_csv(
Expand Down Expand Up @@ -245,21 +251,26 @@ def completed(self):
return AutolabelDataset(filtered_df, self.config)

def incorrect(
self, label: str = None, ground_truth: str = None, label_column: str = None
self,
label: str = None,
ground_truth: str = None,
label_column: str = None,
):
"""
Filter the dataset to only include incorrect items. This means the labels
where the llm label was incorrect.

Args:
label: The llm label to filter on.
ground_truth: The ground truth label to filter on.
label_column: The column to filter on. This is only used for attribute extraction tasks.

"""
gt_label_column = label_column or self.config.label_column()

if gt_label_column is None:
raise ValueError(
"Cannot compute mistakes without ground truth label column"
"Cannot compute mistakes without ground truth label column",
)

filtered_df = self.df[
Expand All @@ -281,8 +292,10 @@ def correct(self, label_column: str = None):
"""
Filter the dataset to only include correct items. This means the labels
where the llm label was correct.

Args:
label_column: The column to filter on. This is only used for attribute extraction tasks.

"""
gt_label_column = label_column or self.config.label_column()

Expand All @@ -298,12 +311,14 @@ def correct(self, label_column: str = None):
def filter_by_confidence(self, threshold: float = 0.5):
"""
Filter the dataset to only include items with confidence scores greater than the threshold.

Args:
threshold: The threshold to filter on. This means that only items with confidence scores greater than the threshold will be included.

"""
if not self.config.confidence():
raise ValueError(
"Cannot compute correct and confident without confidence scores"
"Cannot compute correct and confident without confidence scores",
)

filtered_df = self.df[
Expand Down Expand Up @@ -360,13 +375,13 @@ def _validate(self):

if len(self.__malformed_records) > 0:
logger.warning(
f"Data Validation failed for {len(self.__malformed_records)} records: \n Stats: \n {table}"
f"Data Validation failed for {len(self.__malformed_records)} records: \n Stats: \n {table}",
)
raise DataValidationFailed(
f"Validation failed for {len(self.__malformed_records)} rows."
f"Validation failed for {len(self.__malformed_records)} rows.",
)

def generate_label_name(self, col_name: str, label_column: str = None):
def generate_label_name(self, col_name: str, label_column: str = None) -> str:
label_column = label_column or f"{self.config.task_name()}_task"
return f"{label_column}_{col_name}"

Expand Down
Loading
Loading