Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 65 additions & 57 deletions packages/hub/src/lib/commit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,13 @@ export type CommitParams = {
*/
fetch?: typeof fetch;
abortSignal?: AbortSignal;
// Credentials are optional due to custom fetch functions or cookie auth
/**
* @default true
*
* Use xet protocol: https://huggingface.co/blog/xet-on-the-hub to upload, rather than a basic S3 PUT
*/
useXet?: boolean;
// Credentials are optional due to custom fetch functions or cookie auth
} & Partial<CredentialsParams>;

export interface CommitOutput {
Expand Down Expand Up @@ -165,24 +170,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
const repoId = toRepoId(params.repo);
yield { event: "phase", phase: "preuploading" };

let useXet = params.useXet;
if (useXet) {
const info = await (params.fetch ?? fetch)(
`${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}?expand[]=xetEnabled`,
{
headers: {
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
},
}
);

if (!info.ok) {
throw await createApiError(info);
}

const data = await info.json();
useXet = !!data.xetEnabled;
}
let useXet = params.useXet ?? true;

const lfsShas = new Map<string, string | null>();

Expand All @@ -206,10 +194,6 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
const allOperations = (
await Promise.all(
params.operations.map(async (operation) => {
if (operation.operation === "edit" && !useXet) {
throw new Error("Edit operation is not supported when Xet is disabled");
}

if (operation.operation === "edit") {
// Convert EditFile operation to a file operation with SplicedBlob
const splicedBlob = SplicedBlob.create(
Expand Down Expand Up @@ -325,7 +309,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
const payload: ApiLfsBatchRequest = {
operation: "upload",
// multipart is a custom protocol for HF
transfers: ["basic", "multipart"],
transfers: ["basic", "multipart", ...(useXet ? ["xet" as const] : [])],
hash_algo: "sha_256",
...(!params.isPullRequest && {
ref: {
Expand Down Expand Up @@ -363,6 +347,12 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr

const shaToOperation = new Map(operations.map((op, i) => [shas[i], op]));

if (useXet && json.transfer !== "xet") {
useXet = false;
}
let xetRefreshWriteTokenUrl: string | undefined;
let xetSessionId: string | undefined;

if (useXet) {
// First get all the files that are already uploaded out of the way
for (const obj of json.objects) {
Expand All @@ -386,6 +376,17 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
progress: 1,
state: "uploading",
};
} else {
xetRefreshWriteTokenUrl = obj.actions.upload.href;
// Also, obj.actions.upload.header: {
// X-Xet-Cas-Url: string;
// X-Xet-Access-Token: string;
// X-Xet-Token-Expiration: string;
// X-Xet-Session-Id: string;
// }
const headers = new Headers(obj.actions.upload.header);
xetSessionId = headers.get("X-Xet-Session-Id") ?? undefined;
// todo: use other data, like x-xet-cas-url, ...
}
}
const source = (async function* () {
Expand All @@ -395,43 +396,50 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
continue;
}
abortSignal?.throwIfAborted();

yield { content: op.content, path: op.path, sha256: obj.oid };
}
})();
const sources = splitAsyncGenerator(source, 5);
yield* eventToGenerator((yieldCallback, returnCallback, rejectCallback) =>
Promise.all(
sources.map(async function (source) {
for await (const event of uploadShards(source, {
fetch: params.fetch,
accessToken,
hubUrl: params.hubUrl ?? HUB_URL,
repo: repoId,
// todo: maybe leave empty if PR?
rev: params.branch ?? "main",
isPullRequest: params.isPullRequest,
yieldCallback: (event) => yieldCallback({ ...event, state: "uploading" }),
})) {
if (event.event === "file") {
yieldCallback({
event: "fileProgress" as const,
path: event.path,
progress: 1,
state: "uploading" as const,
});
} else if (event.event === "fileProgress") {
yieldCallback({
event: "fileProgress" as const,
path: event.path,
progress: event.progress,
state: "uploading" as const,
});
if (xetRefreshWriteTokenUrl) {
const xetRefreshWriteTokenUrlFixed = xetRefreshWriteTokenUrl;
const sources = splitAsyncGenerator(source, 5);
yield* eventToGenerator((yieldCallback, returnCallback, rejectCallback) =>
Promise.all(
sources.map(async function (source) {
for await (const event of uploadShards(source, {
fetch: params.fetch,
accessToken,
hubUrl: params.hubUrl ?? HUB_URL,
repo: repoId,
xetRefreshWriteTokenUrl: xetRefreshWriteTokenUrlFixed,
xetSessionId,
// todo: maybe leave empty if PR?
rev: params.branch ?? "main",
isPullRequest: params.isPullRequest,
yieldCallback: (event) => yieldCallback({ ...event, state: "uploading" }),
})) {
if (event.event === "file") {
// No need: uploading xorbs already sent a fileProgress event with progress 1
// yieldCallback({
// event: "fileProgress" as const,
// path: event.path,
// progress: 1,
// state: "uploading" as const,
// });
} else if (event.event === "fileProgress") {
yieldCallback({
event: "fileProgress" as const,
path: event.path,
progress: event.progress,
state: "uploading" as const,
});
}
}
}
})
).then(() => returnCallback(undefined), rejectCallback)
);
})
).then(() => returnCallback(undefined), rejectCallback)
);
} else {
// No LFS file to upload
}
} else {
yield* eventToGenerator<CommitProgressEvent, void>((yieldCallback, returnCallback, rejectCallback) => {
return promisesQueueStreaming(
Expand Down
Loading