"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""

from __future__ import annotations
from mistralai.types import BaseModel, Nullable, OptionalNullable, UNSET, UNSET_SENTINEL
from pydantic import model_serializer
from typing import Optional
from typing_extensions import NotRequired, TypedDict


class TrainingParametersInTypedDict(TypedDict):
    r"""The fine-tuning hyperparameter settings used in a fine-tune job."""

    training_steps: NotRequired[Nullable[int]]
    r"""The number of training steps to perform. A training step refers to a single update of the model weights during the fine-tuning process. This update is typically calculated using a batch of samples from the training dataset."""
    learning_rate: NotRequired[float]
    r"""A parameter describing how much to adjust the pre-trained model's weights in response to the estimated error each time the weights are updated during the fine-tuning process."""
    weight_decay: NotRequired[Nullable[float]]
    r"""(Advanced Usage) Weight decay adds a term to the loss function that is proportional to the sum of the squared weights. This term reduces the magnitude of the weights and prevents them from growing too large."""
    warmup_fraction: NotRequired[Nullable[float]]
    r"""(Advanced Usage) A parameter that specifies the percentage of the total training steps at which the learning rate warm-up phase ends. During this phase, the learning rate gradually increases from a small value to the initial learning rate, helping to stabilize the training process and improve convergence. Similar to `pct_start` in [mistral-finetune](https://github.com/mistralai/mistral-finetune)"""
    epochs: NotRequired[Nullable[float]]
    r"""The number of epochs to train for, i.e. the number of complete passes through the training dataset."""
    fim_ratio: NotRequired[Nullable[float]]
    r"""(Advanced Usage) The rate of fill-in-the-middle (FIM) training examples."""
    seq_len: NotRequired[Nullable[int]]
    r"""(Advanced Usage) The sequence length used during training."""


class TrainingParametersIn(BaseModel):
    r"""The fine-tuning hyperparameter settings used in a fine-tune job."""

    training_steps: OptionalNullable[int] = UNSET
    r"""The number of training steps to perform. A training step refers to a single update of the model weights during the fine-tuning process. This update is typically calculated using a batch of samples from the training dataset."""

    learning_rate: Optional[float] = 0.0001
    r"""A parameter describing how much to adjust the pre-trained model's weights in response to the estimated error each time the weights are updated during the fine-tuning process."""

    weight_decay: OptionalNullable[float] = UNSET
    r"""(Advanced Usage) Weight decay adds a term to the loss function that is proportional to the sum of the squared weights. This term reduces the magnitude of the weights and prevents them from growing too large."""

    warmup_fraction: OptionalNullable[float] = UNSET
    r"""(Advanced Usage) A parameter that specifies the percentage of the total training steps at which the learning rate warm-up phase ends. During this phase, the learning rate gradually increases from a small value to the initial learning rate, helping to stabilize the training process and improve convergence. Similar to `pct_start` in [mistral-finetune](https://github.com/mistralai/mistral-finetune)"""

    epochs: OptionalNullable[float] = UNSET
    r"""The number of epochs to train for, i.e. the number of complete passes through the training dataset."""

    fim_ratio: OptionalNullable[float] = UNSET
    r"""(Advanced Usage) The rate of fill-in-the-middle (FIM) training examples."""

    seq_len: OptionalNullable[int] = UNSET
    r"""(Advanced Usage) The sequence length used during training."""

    @model_serializer(mode="wrap")
    def serialize_model(self, handler):
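        # Serialize by hand so that optional fields left at the UNSET sentinel
        # are omitted from the output, while explicit ``None`` values on
        # nullable fields the caller actually set are preserved.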
        optional_fields = [
            "training_steps",
            "learning_rate",
            "weight_decay",
            "warmup_fraction",
            "epochs",
            "fim_ratio",
            "seq_len",
        ]
        nullable_fields = [
            "training_steps",
            "weight_decay",
            "warmup_fraction",
            "epochs",
            "fim_ratio",
            "seq_len",
        ]
        null_default_fields = []

        serialized = handler(self)

        m = {}

        for n, f in type(self).model_fields.items():
            k = f.alias or n
            val = serialized.get(k)
            serialized.pop(k, None)

            optional_nullable = k in optional_fields and k in nullable_fields
            is_set = (
                n in self.__pydantic_fields_set__
                or k in null_default_fields
            )  # pylint: disable=no-member

            if val is not None and val != UNSET_SENTINEL:
                m[k] = val
            elif val != UNSET_SENTINEL and (
                k not in optional_fields or (optional_nullable and is_set)
            ):
                m[k] = val

        return m
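

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the generated API
    # surface): fields left at UNSET are dropped from the serialized payload,
    # while an explicit ``None`` on a nullable field such as ``training_steps``
    # is kept. The hyperparameter values below are arbitrary examples.
    params = TrainingParametersIn(training_steps=None, learning_rate=2e-5)
    print(params.model_dump())
    # Expected output: {'training_steps': None, 'learning_rate': 2e-05}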
